ContainerImage.Pinniped/pkg/oidcclient/login_test.go
Matt Moyer 1e32530d7b
Fix broken TTY after manual auth code prompt.
This may be a temporary fix. It switches the manual auth code prompt to use `promptForValue()` instead of `promptForSecret()`. The `promptForSecret()` function no longer supports cancellation (the v0.9.2 behavior) and the method of cancelling in `promptForValue()` is now based on running the blocking read in a background goroutine, which is allowed to block forever or leak (which is not important for our CLI use case).

This means that the authorization code is now visible in the user's terminal, but this is really not a big deal because of PKCE and the limited lifetime of an auth code.

The main goroutine now correctly waits for the "manual prompt" goroutine to clean up, which now includes printing the extra newline that would normally have been entered by the user in the manual flow.

The text of the manual login prompt is updated to be more concise and less scary (don't use the word "fail").

Signed-off-by: Matt Moyer <moyerm@vmware.com>
2021-07-30 12:45:44 -05:00

1930 lines
77 KiB
Go

// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidcclient
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"syscall"
"testing"
"time"
"github.com/go-logr/stdr"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/klog/v2"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/roundtripper"
"go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/testlogger"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
"go.pinniped.dev/pkg/oidcclient/state"
)
// mockSessionCache exists to avoid an import cycle if we generate mocks into another package.
type mockSessionCache struct {
t *testing.T
getReturnsToken *oidctypes.Token
sawGetKeys []SessionCacheKey
sawPutKeys []SessionCacheKey
sawPutTokens []*oidctypes.Token
}
func (m *mockSessionCache) GetToken(key SessionCacheKey) *oidctypes.Token {
m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID)
m.sawGetKeys = append(m.sawGetKeys, key)
return m.getReturnsToken
}
func (m *mockSessionCache) PutToken(key SessionCacheKey, token *oidctypes.Token) {
m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token)
m.sawPutKeys = append(m.sawPutKeys, key)
m.sawPutTokens = append(m.sawPutTokens, token)
}
func TestLogin(t *testing.T) { // nolint:gocyclo
time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC)
time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
}
testExchangedToken := oidctypes.Token{
IDToken: &oidctypes.IDToken{Token: "test-id-token-with-requested-audience", Expiry: metav1.NewTime(time1.Add(3 * time.Minute))},
}
// Start a test server that returns 500 errors
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "some discovery error", http.StatusInternalServerError)
}))
t.Cleanup(errorServer.Close)
// Start a test server that returns discovery data with a broken response_modes_supported value.
brokenResponseModeMux := http.NewServeMux()
brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux)
brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string).
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenResponseModeServer.URL,
ResponseModesSupported: "invalid",
})
})
t.Cleanup(brokenResponseModeServer.Close)
// Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
brokenTokenURLMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenTokenURLServer.URL,
AuthURL: brokenTokenURLServer.URL + "/authorize",
TokenURL: "%",
JWKSURL: brokenTokenURLServer.URL + "/keys",
})
})
t.Cleanup(brokenTokenURLServer.Close)
discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
w.Header().Set("content-type", "application/json")
_ = json.NewEncoder(w).Encode(&struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
}{
Issuer: server.URL,
AuthURL: server.URL + "/authorize",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/keys",
ResponseModesSupported: responseModes,
})
}
}
tokenHandler := func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
}
switch r.Form.Get("grant_type") {
case "refresh_token":
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token
response.IDToken = testToken.IDToken.Token
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
response.IDToken = "not a valid JWT"
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
return
}
case "urn:ietf:params:oauth:grant-type:token-exchange":
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "bad client_id", http.StatusBadRequest)
return
}
switch r.Form.Get("audience") {
case "test-audience-produce-invalid-http-response":
http.Redirect(w, r, "%", http.StatusTemporaryRedirect)
return
case "test-audience-produce-http-400":
http.Error(w, "some server error", http.StatusBadRequest)
return
case "test-audience-produce-invalid-content-type":
w.Header().Set("content-type", "invalid/invalid;=")
return
case "test-audience-produce-wrong-content-type":
w.Header().Set("content-type", "invalid")
return
case "test-audience-produce-invalid-json":
w.Header().Set("content-type", "application/json;charset=UTF-8")
_, _ = w.Write([]byte(`{`))
return
case "test-audience-produce-invalid-tokentype":
response.TokenType = "invalid"
case "test-audience-produce-invalid-issuedtokentype":
response.TokenType = "N_A"
response.IssuedTokenType = "invalid"
case "test-audience-produce-invalid-jwt":
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = "not-a-valid-jwt"
default:
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = testExchangedToken.IDToken.Token
}
default:
http.Error(w, fmt.Sprintf("invalid grant_type %q", r.Form.Get("grant_type")), http.StatusBadRequest)
return
}
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}
// Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil))
providerMux.HandleFunc("/token", tokenHandler)
// Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post.
formPostProviderMux := http.NewServeMux()
formPostSuccessServer := httptest.NewServer(formPostProviderMux)
t.Cleanup(formPostSuccessServer.Close)
formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"}))
formPostProviderMux.HandleFunc("/token", tokenHandler)
defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam
// Call the handler function from the test server to calculate the response.
handler, _ := providerMux.Handler(req)
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
return recorder.Result(), nil
}
defaultLDAPTestOpts := func(t *testing.T, h *handlerState, authResponse *http.Response, authError error) error { // nolint:unparam
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
require.NoError(t, WithClient(&http.Client{
Transport: roundtripper.Func(func(req *http.Request) (*http.Response, error) {
switch req.URL.Scheme + "://" + req.URL.Host + req.URL.Path {
case "http://" + successServer.Listener.Addr().String() + "/.well-known/openid-configuration":
return defaultDiscoveryResponse(req)
case "http://" + successServer.Listener.Addr().String() + "/authorize":
return authResponse, authError
default:
require.FailNow(t, fmt.Sprintf("saw unexpected http call from the CLI: %s", req.URL.String()))
return nil, nil
}
}),
})(h))
return nil
}
tests := []struct {
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *oidctypes.Token
wantLogs []string
}{
{
name: "option error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return fmt.Errorf("some option error")
}
},
wantErr: "some option error",
},
{
name: "error generating state",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "", fmt.Errorf("some error generating state") }
return nil
}
},
wantErr: "some error generating state",
},
{
name: "error generating nonce",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateNonce = func() (nonce.Nonce, error) { return "", fmt.Errorf("some error generating nonce") }
return nil
}
},
wantErr: "some error generating nonce",
},
{
name: "error generating PKCE",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generatePKCE = func() (pkce.Code, error) { return "", fmt.Errorf("some error generating PKCE") }
return nil
}
},
wantErr: "some error generating PKCE",
},
{
name: "session cache hit but token expired",
issuer: "test-issuer",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity
},
}}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
return WithSessionCache(cache)(h)
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"test-issuer\""},
wantErr: `could not perform OIDC discovery for "test-issuer": Get "test-issuer/.well-known/openid-configuration": unsupported protocol scheme ""`,
},
{
name: "session cache hit with valid token",
issuer: "test-issuer",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
return WithSessionCache(cache)(h)
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\""},
wantToken: &testToken,
},
{
name: "discovery failure due to 500 error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
issuer: errorServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""},
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "discovery failure due to invalid response_modes_supported",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
issuer: brokenResponseModeServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""},
wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL),
},
{
name: "session cache hit with refreshable token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"",
"\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""},
wantToken: &testToken,
},
{
name: "session cache hit but refresh returns invalid token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(nil, fmt.Errorf("some validation error"))
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"",
"\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""},
wantErr: "some validation error",
},
{
name: "session cache hit but refresh fails",
issuer: successServer.URL,
clientID: "not-the-test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(int) bool { return false }
return nil
}
},
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"level"=4 "msg"="Pinniped: Refreshing cached token."`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
// Expect this to fall through to the authorization code flow, so it fails here.
wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
},
{
name: "listen failure and non-tty stdin",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.listen = func(net string, addr string) (net.Listener, error) {
assert.Equal(t, "tcp", net)
assert.Equal(t, "localhost:0", addr)
return nil, fmt.Errorf("some listen error")
}
h.isTTY = func(fd int) bool {
assert.Equal(t, fd, syscall.Stdin)
return false
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
},
{
name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
require.NoError(t, WithSkipListen()(h))
h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "timeout waiting for callback",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
ctx, cancel := context.WithCancel(h.ctx)
h.ctx = ctx
h.openURL = func(_ string) error {
cancel()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "timed out waiting for token callback: context canceled",
},
{
name: "callback returns error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.openURL = func(_ string) error {
go func() {
h.callbacks <- callbackResult{err: fmt.Errorf("some callback error")}
}()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error handling callback: some callback error",
},
{
name: "callback returns success",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, successServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken,
},
{
name: "callback returns success with request_mode=form_post",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: formPostSuccessServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"response_mode": []string{"form_post"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantToken: &testToken,
},
{
name: "upstream name and type are included in authorize request if upstream name is provided",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "oidc")(h))
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"oidc"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, successServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken,
},
{
name: "ldap login when prompting for username returns an error",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error")
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error prompting for username: some prompt error",
},
{
name: "ldap login when prompting for password returns an error",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error prompting for password: some prompt error",
},
{
name: "ldap login when there is a problem with parsing the authorize URL",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
require.NoError(t, WithClient(&http.Client{
Transport: roundtripper.Func(func(req *http.Request) (*http.Response, error) {
switch req.URL.Scheme + "://" + req.URL.Host + req.URL.Path {
case "http://" + successServer.Listener.Addr().String() + "/.well-known/openid-configuration":
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
jsonResponseBody, err := json.Marshal(&providerJSON{
Issuer: successServer.URL,
AuthURL: "%", // this is not a legal URL!
TokenURL: successServer.URL + "/token",
JWKSURL: successServer.URL + "/keys",
})
require.NoError(t, err)
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"content-type": []string{"application/json"}},
Body: ioutil.NopCloser(strings.NewReader(string(jsonResponseBody))),
}, nil
default:
require.FailNow(t, fmt.Sprintf("saw unexpected http call from the CLI: %s", req.URL.String()))
return nil, nil
}
}),
})(h))
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `could not build authorize request: parse "%?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": invalid URL escape "%"`,
},
{
name: "ldap login when there is an error calling the authorization endpoint",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, nil, errors.New("some error fetching authorize endpoint"))
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `authorization response error: Get "http://` + successServer.Listener.Addr().String() +
`/authorize?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": some error fetching authorize endpoint`,
},
{
name: "ldap login when the OIDC provider authorization endpoint returns something other than a 302 redirect",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, &http.Response{StatusCode: http.StatusBadGateway, Status: "502 Bad Gateway"}, nil)
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `error getting authorization: expected to be redirected, but response status was 502 Bad Gateway`,
},
{
name: "ldap login when the OIDC provider authorization endpoint redirect has an error and error description",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
"http://127.0.0.1:0/callback?error=access_denied&error_description=optional-error-description&state=test-state",
}},
}, nil)
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `login failed with code "access_denied": optional-error-description`,
},
{
name: "ldap login when the OIDC provider authorization endpoint redirects us to a different server",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
"http://other-server.example.com/callback?code=foo&state=test-state",
}},
}, nil)
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `error getting authorization: redirected to the wrong location: http://other-server.example.com/callback?code=foo&state=test-state`,
},
{
name: "ldap login when the OIDC provider authorization endpoint redirect has an error but no error description",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
"http://127.0.0.1:0/callback?error=access_denied&state=test-state",
}},
}, nil)
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `login failed with code "access_denied"`,
},
{
name: "ldap login when the OIDC provider authorization endpoint redirect has the wrong state value",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return defaultLDAPTestOpts(t, h, &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{"http://127.0.0.1:0/callback?code=foo&state=wrong-state"}},
}, nil)
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `missing or invalid state parameter in authorization response: http://127.0.0.1:0/callback?code=foo&state=wrong-state`,
},
{
name: "ldap login when there is an error exchanging the authcode or validating the tokens",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value"
_ = defaultLDAPTestOpts(t, h, &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
fmt.Sprintf("http://127.0.0.1:0/callback?code=%s&state=test-state", fakeAuthCode),
}},
}, nil)
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(
gomock.Any(), fakeAuthCode, pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), "http://127.0.0.1:0/callback").
Return(nil, errors.New("some authcode exchange or token validation error"))
return mock
}
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error during authorization code exchange: some authcode exchange or token validation error",
},
{
name: "successful ldap login with prompts for username and password",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value"
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(
gomock.Any(), fakeAuthCode, pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), "http://127.0.0.1:0/callback").
Return(&testToken, nil)
return mock
}
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.getEnv = func(_ string) string {
return "" // asking for any env var returns empty as if it were unset
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
discoveryRequestWasMade := false
authorizeRequestWasMade := false
t.Cleanup(func() {
require.True(t, discoveryRequestWasMade, "should have made an discovery request")
require.True(t, authorizeRequestWasMade, "should have made an authorize request")
})
require.NoError(t, WithClient(&http.Client{
Transport: roundtripper.Func(func(req *http.Request) (*http.Response, error) {
switch req.URL.Scheme + "://" + req.URL.Host + req.URL.Path {
case "http://" + successServer.Listener.Addr().String() + "/.well-known/openid-configuration":
discoveryRequestWasMade = true
return defaultDiscoveryResponse(req)
case "http://" + successServer.Listener.Addr().String() + "/authorize":
authorizeRequestWasMade = true
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query())
return &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
fmt.Sprintf("http://127.0.0.1:0/callback?code=%s&state=test-state", fakeAuthCode),
}},
}, nil
default:
// Note that "/token" requests should not be made. They are mocked by mocking calls to ExchangeAuthcodeAndValidateTokens().
require.FailNow(t, fmt.Sprintf("saw unexpected http call from the CLI: %s", req.URL.String()))
return nil, nil
}
}),
})(h))
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken,
},
{
name: "successful ldap login with env vars for username and password",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value"
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(
gomock.Any(), fakeAuthCode, pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), "http://127.0.0.1:0/callback").
Return(&testToken, nil)
return mock
}
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.getEnv = func(key string) string {
switch key {
case "PINNIPED_USERNAME":
return "some-upstream-username"
case "PINNIPED_PASSWORD":
return "some-upstream-password"
default:
return "" // all other env vars are treated as if they are unset
}
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
discoveryRequestWasMade := false
authorizeRequestWasMade := false
t.Cleanup(func() {
require.True(t, discoveryRequestWasMade, "should have made an discovery request")
require.True(t, authorizeRequestWasMade, "should have made an authorize request")
})
require.NoError(t, WithClient(&http.Client{
Transport: roundtripper.Func(func(req *http.Request) (*http.Response, error) {
switch req.URL.Scheme + "://" + req.URL.Host + req.URL.Path {
case "http://" + successServer.Listener.Addr().String() + "/.well-known/openid-configuration":
discoveryRequestWasMade = true
return defaultDiscoveryResponse(req)
case "http://" + successServer.Listener.Addr().String() + "/authorize":
authorizeRequestWasMade = true
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query())
return &http.Response{
StatusCode: http.StatusFound,
Header: http.Header{"Location": []string{
fmt.Sprintf("http://127.0.0.1:0/callback?code=%s&state=test-state", fakeAuthCode),
}},
}, nil
default:
// Note that "/token" requests should not be made. They are mocked by mocking calls to ExchangeAuthcodeAndValidateTokens().
require.FailNow(t, fmt.Sprintf("saw unexpected http call from the CLI: %s", req.URL.String()))
return nil, nil
}
}),
})(h))
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"",
"\"level\"=4 \"msg\"=\"Pinniped: Read username from environment variable\" \"name\"=\"PINNIPED_USERNAME\"",
"\"level\"=4 \"msg\"=\"Pinniped: Read password from environment variable\" \"name\"=\"PINNIPED_PASSWORD\"",
},
wantToken: &testToken,
},
{
name: "with requested audience, session cache hit with valid token, but discovery fails",
clientID: "test-client-id",
issuer: errorServer.URL,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: errorServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"cluster-1234\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""},
wantErr: fmt.Sprintf("failed to exchange token: could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token URL is invalid",
issuer: brokenTokenURLServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: brokenTokenURLServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"cluster-1234\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenTokenURLServer.URL + "\""},
wantErr: `failed to exchange token: could not build RFC8693 request: parse "%": invalid URL escape "%"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request fails",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-http-response")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-http-response\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: fmt.Sprintf(`failed to exchange token: Post "%s/token": failed to parse Location header "%%": parse "%%": invalid URL escape "%%"`, successServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns non-200",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-http-400")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-http-400\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: unexpected HTTP response status 400`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid content-type header",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-content-type")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-content-type\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: failed to decode content-type header: mime: invalid media parameter`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns wrong content-type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-wrong-content-type")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-wrong-content-type\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: unexpected HTTP response content type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JSON",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-json")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-json\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: failed to decode response: unexpected EOF`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-tokentype")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-tokentype\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: got unexpected token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid issued_token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-issuedtokentype")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-issuedtokentype\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: got unexpected issued_token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JWT",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-jwt")(h))
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience-produce-invalid-jwt\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `failed to exchange token: received invalid JWT: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts`,
},
{
name: "with requested audience, session cache hit with valid token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Found unexpired cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testExchangedToken,
},
{
name: "with requested audience, session cache hit with valid refresh token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
return mock
}
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantLogs: []string{
"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"",
"\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\"",
"\"level\"=4 \"msg\"=\"Pinniped: Performing RFC8693 token exchange\" \"requestedAudience\"=\"test-audience\"",
},
wantToken: &testExchangedToken,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
testLogger := testlogger.New(t)
klog.SetLogger(testLogger)
stdr.SetVerbosity(debugLogLevel) // set stdr's global log level to debug so the test logger will send output.
tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()),
WithListenPort(0),
WithScopes([]string{"test-scope"}),
WithSkipBrowserOpen(),
tt.opt(t),
WithLogger(testLogger),
)
testLogger.Expect(tt.wantLogs)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
return
}
require.NoError(t, err)
if tt.wantToken == nil {
require.Nil(t, tok)
return
}
require.NotNil(t, tok)
if want := tt.wantToken.AccessToken; want != nil {
require.NotNil(t, tok.AccessToken)
require.Equal(t, want.Token, tok.AccessToken.Token)
require.Equal(t, want.Type, tok.AccessToken.Type)
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.AccessToken)
}
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
if want := tt.wantToken.IDToken; want != nil {
require.NotNil(t, tok.IDToken)
require.Equal(t, want.Token, tok.IDToken.Token)
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.IDToken)
}
})
}
}
func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
tests := []struct {
name string
opt func(t *testing.T) Option
wantCallback *callbackResult
}{
{
name: "no stdin available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool {
require.Equal(t, syscall.Stdin, fd)
return false
}
h.useFormPost = true
return nil
}
},
},
{
name: "no form_post mode available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = false
return nil
}
},
},
{
name: "prompt fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"),
},
},
{
name: "redeeming code fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("some exchange error"),
},
},
{
name: "success",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
wantCallback: &callbackResult{
token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var buf bytes.Buffer
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(),
)
if tt.wantCallback != nil {
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
case result := <-h.callbacks:
require.Equal(t, *tt.wantCallback, result)
}
}
})
}
}
func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
withFormPostMode := func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
return nil
}
}
tests := []struct {
name string
method string
query string
body []byte
contentType string
opt func(t *testing.T) Option
wantErr string
wantHTTPStatus int
}{
{
name: "wrong method",
method: "POST",
query: "",
wantErr: "wanted GET",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "wrong method for form_post",
method: "GET",
query: "",
opt: withFormPostMode,
wantErr: "wanted POST",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "invalid form for form_post",
method: "POST",
query: "",
contentType: "application/x-www-form-urlencoded",
body: []byte(`%`),
opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid state",
query: "state=invalid",
wantErr: "missing or invalid state parameter",
wantHTTPStatus: http.StatusForbidden,
},
{
name: "error code from provider",
query: "state=test-state&error=some_error",
wantErr: `login failed with code "some_error"`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "error code with a description from provider",
query: "state=test-state&error=some_error&error_description=optional%20error%20description",
wantErr: `login failed with code "some_error": optional error description`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid code",
query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: some exchange error",
wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
},
},
{
name: "valid",
query: "state=test-state&code=valid",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
},
{
name: "valid form_post",
method: http.MethodPost,
contentType: "application/x-www-form-urlencoded",
body: []byte(`state=test-state&code=valid`),
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
resp := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body))
require.NoError(t, err)
req.URL.RawQuery = tt.query
if tt.method != "" {
req.Method = tt.method
}
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
if tt.wantHTTPStatus != 0 {
rec := httptest.NewRecorder()
err.(httperr.Responder).Respond(rec)
require.Equal(t, tt.wantHTTPStatus, rec.Code)
}
} else {
require.NoError(t, err)
}
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
case result := <-h.callbacks:
if tt.wantErr != "" {
require.EqualError(t, result.err, tt.wantErr)
return
}
require.NoError(t, result.err)
require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, "test-id-token")
}
})
}
}
func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
t.Helper()
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
}
// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
type hasAccessTokenMatcher struct{ expected string }
func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
return arg.(*oauth2.Token).AccessToken == m.expected
}
func (m hasAccessTokenMatcher) Got(got interface{}) string {
return got.(*oauth2.Token).AccessToken
}
func (m hasAccessTokenMatcher) String() string {
return m.expected
}
func HasAccessToken(expected string) gomock.Matcher {
return hasAccessTokenMatcher{expected: expected}
}