Refactor oidcclient.Login to use new upstreamoidc package.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-11-30 17:14:57 -06:00
parent 4b60c922ef
commit b272b3f331
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
3 changed files with 98 additions and 197 deletions

View File

@ -20,6 +20,10 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
return &ProviderConfig{Config: config, Provider: provider}
}
// ProviderConfig holds the active configuration of an upstream OIDC provider. // ProviderConfig holds the active configuration of an upstream OIDC provider.
type ProviderConfig struct { type ProviderConfig struct {
Name string Name string
@ -31,9 +35,6 @@ type ProviderConfig struct {
} }
} }
// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI.
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
func (p *ProviderConfig) GetName() string { func (p *ProviderConfig) GetName() string {
return p.Name return p.Name
} }

View File

@ -16,10 +16,11 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/upstreamoidc"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
@ -52,18 +53,18 @@ type handlerState struct {
callbackPath string callbackPath string
// Generated parameters of a login flow. // Generated parameters of a login flow.
idTokenVerifier *oidc.IDTokenVerifier provider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
state state.State state state.State
nonce nonce.Nonce nonce nonce.Nonce
pkce pkce.Code pkce pkce.Code
// External calls for things. // External calls for things.
generateState func() (state.State, error) generateState func() (state.State, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
oidcDiscover func(context.Context, string) (discoveryI, error) getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -152,11 +153,6 @@ type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
type discoveryI interface {
Endpoint() oauth2.Endpoint
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener. // Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
h := handlerState{ h := handlerState{
@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { getProvider: upstreamoidc.New,
return oidc.NewProvider(ctx, iss)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
// Perform OIDC discovery. // Perform OIDC discovery.
discovered, err := h.oidcDiscover(h.ctx, h.issuer) h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
} }
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{ h.oauth2Config = &oauth2.Config{
ClientID: h.clientID, ClientID: h.clientID,
Endpoint: discovered.Endpoint(), Endpoint: h.provider.Endpoint(),
Scopes: h.scopes, Scopes: h.scopes,
} }
@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.validateToken(ctx, refreshed, false) token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "")
if err != nil {
return nil, err
}
return &token, nil
} }
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
} }
// Exchange the authorization code for access, ID, and refresh tokens. // Exchange the authorization code for access, ID, and refresh tokens and perform required
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) // validations on the returned ID token.
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }
// Perform required validations on the returned ID token. h.callbacks <- callbackResult{token: &token}
token, err := h.validateToken(r.Context(), oauth2Tok, true)
if err != nil {
return err
}
h.callbacks <- callbackResult{token: token}
_, _ = w.Write([]byte("you have been logged in and may now close this tab")) _, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil return nil
} }
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if checkNonce {
if err := h.nonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
return &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: tok.AccessToken,
Type: tok.TokenType,
Expiry: metav1.NewTime(tok.Expiry),
},
RefreshToken: &oidctypes.RefreshToken{
Token: tok.RefreshToken,
},
IDToken: &oidctypes.IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}, nil
}
func (h *handlerState) serve(listener net.Listener) func() { func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))

View File

@ -18,11 +18,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -57,19 +57,9 @@ func TestLogin(t *testing.T) {
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := oidctypes.Token{ testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
Token: "test-access-token", RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
},
RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token",
},
IDToken: &oidctypes.IDToken{
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
},
} }
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
@ -78,7 +68,7 @@ func TestLogin(t *testing.T) {
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset and answers refresh requests. // Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close) t.Cleanup(successServer.Close)
@ -248,6 +238,14 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(testToken, nil, nil)
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
@ -268,12 +266,6 @@ func TestLogin(t *testing.T) {
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
@ -285,6 +277,14 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
@ -298,16 +298,10 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", wantErr: "some validation error",
}, },
{ {
name: "session cache hit but refresh fails", name: "session cache hit but refresh fails",
@ -328,12 +322,6 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
h.listenAddr = "invalid-listen-address" h.listenAddr = "invalid-listen-address"
return nil return nil
@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
name string name string
method string method string
query string query string
returnIDTok string opt func(t *testing.T) Option
wantErr string wantErr string
wantHTTPStatus int wantHTTPStatus int
}{ }{
@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{ {
name: "invalid code", name: "invalid code",
query: "state=test-state&code=invalid", query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", wantErr: "could not complete code exchange: some exchange error",
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, opt: func(t *testing.T) Option {
{ return func(h *handlerState) error {
name: "missing ID token", h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
query: "state=test-state&code=valid", mock := mockUpstream(t)
returnIDTok: "", mock.EXPECT().
wantErr: "received response missing ID token", ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
wantHTTPStatus: http.StatusBadRequest, Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
}, return mock
{ }
name: "invalid ID token", return nil
query: "state=test-state&code=valid", }
returnIDTok: "invalid-jwt", },
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid access token hash",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
wantErr: "received invalid ID token: access token hash does not match value in ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid nonce",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
wantHTTPStatus: http.StatusBadRequest,
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
}, },
{ {
name: "valid", name: "valid",
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
opt: func(t *testing.T) Option {
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: return func(h *handlerState) error {
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
return mock
}
return nil
}
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
require.NotEmpty(t, r.Form.Get("code"))
if r.Form.Get("code") != "valid" {
http.Error(w, "invalid authorization code", http.StatusForbidden)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
}
response.AccessToken = "test-access-token"
response.Expiry = time.Now().Add(time.Hour)
response.IDToken = tt.returnIDTok
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
oauth2Config: &oauth2.Config{ }
ClientID: "test-client-id", if tt.opt != nil {
RedirectURL: "http://localhost:12345/callback", require.NoError(t, tt.opt(t)(h))
Endpoint: oauth2.Endpoint{
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
},
idTokenVerifier: mockVerifier(),
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
require.NoError(t, result.err) require.NoError(t, result.err)
require.NotNil(t, result.token) require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) require.Equal(t, result.token.IDToken.Token, "test-id-token")
} }
}) })
} }
} }
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
func mockVerifier() *oidc.IDTokenVerifier { t.Helper()
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) ctrl := gomock.NewController(t)
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). t.Cleanup(ctrl.Finish)
AnyTimes(). return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
return jws.UnsafePayloadWithoutVerification(), nil
})
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
SkipIssuerCheck: true,
SkipExpiryCheck: true,
SkipClientIDCheck: true,
})
} }
type mockDiscovery struct{ provider *oidc.Provider } // hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
type hasAccessTokenMatcher struct{ expected string }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
return arg.(*oauth2.Token).AccessToken == m.expected
}
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } func (m hasAccessTokenMatcher) Got(got interface{}) string {
return got.(*oauth2.Token).AccessToken
}
func (m hasAccessTokenMatcher) String() string {
return m.expected
}
func HasAccessToken(expected string) gomock.Matcher {
return hasAccessTokenMatcher{expected: expected}
}