Refactor oidcclient.Login to use new upstreamoidc package.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-11-30 17:14:57 -06:00
parent 4b60c922ef
commit b272b3f331
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
3 changed files with 98 additions and 197 deletions

View File

@ -20,6 +20,10 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
return &ProviderConfig{Config: config, Provider: provider}
}
// ProviderConfig holds the active configuration of an upstream OIDC provider. // ProviderConfig holds the active configuration of an upstream OIDC provider.
type ProviderConfig struct { type ProviderConfig struct {
Name string Name string
@ -31,9 +35,6 @@ type ProviderConfig struct {
} }
} }
// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI.
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
func (p *ProviderConfig) GetName() string { func (p *ProviderConfig) GetName() string {
return p.Name return p.Name
} }

View File

@ -16,10 +16,11 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/upstreamoidc"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
@ -52,7 +53,7 @@ type handlerState struct {
callbackPath string callbackPath string
// Generated parameters of a login flow. // Generated parameters of a login flow.
idTokenVerifier *oidc.IDTokenVerifier provider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
state state.State state state.State
nonce nonce.Nonce nonce nonce.Nonce
@ -63,7 +64,7 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
oidcDiscover func(context.Context, string) (discoveryI, error) getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -152,11 +153,6 @@ type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
type discoveryI interface {
Endpoint() oauth2.Endpoint
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener. // Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
h := handlerState{ h := handlerState{
@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { getProvider: upstreamoidc.New,
return oidc.NewProvider(ctx, iss)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
// Perform OIDC discovery. // Perform OIDC discovery.
discovered, err := h.oidcDiscover(h.ctx, h.issuer) h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
} }
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{ h.oauth2Config = &oauth2.Config{
ClientID: h.clientID, ClientID: h.clientID,
Endpoint: discovered.Endpoint(), Endpoint: h.provider.Endpoint(),
Scopes: h.scopes, Scopes: h.scopes,
} }
@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.validateToken(ctx, refreshed, false) token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "")
if err != nil {
return nil, err
}
return &token, nil
} }
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
} }
// Exchange the authorization code for access, ID, and refresh tokens. // Exchange the authorization code for access, ID, and refresh tokens and perform required
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) // validations on the returned ID token.
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }
// Perform required validations on the returned ID token. h.callbacks <- callbackResult{token: &token}
token, err := h.validateToken(r.Context(), oauth2Tok, true)
if err != nil {
return err
}
h.callbacks <- callbackResult{token: token}
_, _ = w.Write([]byte("you have been logged in and may now close this tab")) _, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil return nil
} }
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if checkNonce {
if err := h.nonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
return &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: tok.AccessToken,
Type: tok.TokenType,
Expiry: metav1.NewTime(tok.Expiry),
},
RefreshToken: &oidctypes.RefreshToken{
Token: tok.RefreshToken,
},
IDToken: &oidctypes.IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}, nil
}
func (h *handlerState) serve(listener net.Listener) func() { func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))

View File

@ -18,11 +18,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -57,19 +57,9 @@ func TestLogin(t *testing.T) {
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := oidctypes.Token{ testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
Token: "test-access-token", RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
},
RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token",
},
IDToken: &oidctypes.IDToken{
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
},
} }
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
@ -78,7 +68,7 @@ func TestLogin(t *testing.T) {
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset and answers refresh requests. // Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close) t.Cleanup(successServer.Close)
@ -248,6 +238,14 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(testToken, nil, nil)
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
@ -268,12 +266,6 @@ func TestLogin(t *testing.T) {
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
@ -285,6 +277,14 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
@ -298,16 +298,10 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", wantErr: "some validation error",
}, },
{ {
name: "session cache hit but refresh fails", name: "session cache hit but refresh fails",
@ -328,12 +322,6 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
h.listenAddr = "invalid-listen-address" h.listenAddr = "invalid-listen-address"
return nil return nil
@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
name string name string
method string method string
query string query string
returnIDTok string opt func(t *testing.T) Option
wantErr string wantErr string
wantHTTPStatus int wantHTTPStatus int
}{ }{
@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{ {
name: "invalid code", name: "invalid code",
query: "state=test-state&code=invalid", query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", wantErr: "could not complete code exchange: some exchange error",
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
}, },
{
name: "missing ID token",
query: "state=test-state&code=valid",
returnIDTok: "",
wantErr: "received response missing ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid ID token",
query: "state=test-state&code=valid",
returnIDTok: "invalid-jwt",
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid access token hash",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
wantErr: "received invalid ID token: access token hash does not match value in ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid nonce",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
wantHTTPStatus: http.StatusBadRequest,
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
}, },
{ {
name: "valid", name: "valid",
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
opt: func(t *testing.T) Option {
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: return func(h *handlerState) error {
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
return mock
}
return nil
}
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
require.NotEmpty(t, r.Form.Get("code"))
if r.Form.Get("code") != "valid" {
http.Error(w, "invalid authorization code", http.StatusForbidden)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
}
response.AccessToken = "test-access-token"
response.Expiry = time.Now().Add(time.Hour)
response.IDToken = tt.returnIDTok
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
oauth2Config: &oauth2.Config{ }
ClientID: "test-client-id", if tt.opt != nil {
RedirectURL: "http://localhost:12345/callback", require.NoError(t, tt.opt(t)(h))
Endpoint: oauth2.Endpoint{
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
},
idTokenVerifier: mockVerifier(),
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
require.NoError(t, result.err) require.NoError(t, result.err)
require.NotNil(t, result.token) require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) require.Equal(t, result.token.IDToken.Token, "test-id-token")
} }
}) })
} }
} }
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
func mockVerifier() *oidc.IDTokenVerifier { t.Helper()
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) ctrl := gomock.NewController(t)
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). t.Cleanup(ctrl.Finish)
AnyTimes(). return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
return jws.UnsafePayloadWithoutVerification(), nil
})
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
SkipIssuerCheck: true,
SkipExpiryCheck: true,
SkipClientIDCheck: true,
})
} }
type mockDiscovery struct{ provider *oidc.Provider } // hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
type hasAccessTokenMatcher struct{ expected string }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
return arg.(*oauth2.Token).AccessToken == m.expected
}
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } func (m hasAccessTokenMatcher) Got(got interface{}) string {
return got.(*oauth2.Token).AccessToken
}
func (m hasAccessTokenMatcher) String() string {
return m.expected
}
func HasAccessToken(expected string) gomock.Matcher {
return hasAccessTokenMatcher{expected: expected}
}