diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 72de2e33..2957e9e3 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -20,6 +20,10 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) +func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + return &ProviderConfig{Config: config, Provider: provider} +} + // ProviderConfig holds the active configuration of an upstream OIDC provider. type ProviderConfig struct { Name string @@ -31,9 +35,6 @@ type ProviderConfig struct { } } -// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI. -var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) - func (p *ProviderConfig) GetName() string { return p.Name } diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 09d6949f..2e286efa 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -16,10 +16,11 @@ import ( "github.com/coreos/go-oidc" "github.com/pkg/browser" "golang.org/x/oauth2" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -52,18 +53,18 @@ type handlerState struct { callbackPath string // Generated parameters of a login flow. - idTokenVerifier *oidc.IDTokenVerifier - oauth2Config *oauth2.Config - state state.State - nonce nonce.Nonce - pkce pkce.Code + provider *oidc.Provider + oauth2Config *oauth2.Config + state state.State + nonce nonce.Nonce + pkce pkce.Code // External calls for things. generateState func() (state.State, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error - oidcDiscover func(context.Context, string) (discoveryI, error) + getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI callbacks chan callbackResult } @@ -152,11 +153,6 @@ type nopCache struct{} func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} -type discoveryI interface { - Endpoint() oauth2.Endpoint - Verifier(*oidc.Config) *oidc.IDTokenVerifier -} - // Login performs an OAuth2/OIDC authorization code login using a localhost listener. func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { h := handlerState{ @@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, - oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { - return oidc.NewProvider(ctx, iss) - }, + getProvider: upstreamoidc.New, } for _, opt := range opts { if err := opt(&h); err != nil { @@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er } // Perform OIDC discovery. - discovered, err := h.oidcDiscover(h.ctx, h.issuer) + h.provider, err = oidc.NewProvider(h.ctx, h.issuer) if err != nil { return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) } - h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID}) // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. h.oauth2Config = &oauth2.Config{ ClientID: h.clientID, - Endpoint: discovered.Endpoint(), + Endpoint: h.provider.Endpoint(), Scopes: h.scopes, } @@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return h.validateToken(ctx, refreshed, false) + token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "") + if err != nil { + return nil, err + } + return &token, nil } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { @@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) } - // Exchange the authorization code for access, ID, and refresh tokens. - oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) + // Exchange the authorization code for access, ID, and refresh tokens and perform required + // validations on the returned ID token. + token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } - // Perform required validations on the returned ID token. - token, err := h.validateToken(r.Context(), oauth2Tok, true) - if err != nil { - return err - } - - h.callbacks <- callbackResult{token: token} + h.callbacks <- callbackResult{token: &token} _, _ = w.Write([]byte("you have been logged in and may now close this tab")) return nil } -func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) { - idTok, hasIDTok := tok.Extra("id_token").(string) - if !hasIDTok { - return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") - } - validated, err := h.idTokenVerifier.Verify(ctx, idTok) - if err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - if validated.AccessTokenHash != "" { - if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - } - if checkNonce { - if err := h.nonce.Validate(validated); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) - } - } - return &oidctypes.Token{ - AccessToken: &oidctypes.AccessToken{ - Token: tok.AccessToken, - Type: tok.TokenType, - Expiry: metav1.NewTime(tok.Expiry), - }, - RefreshToken: &oidctypes.RefreshToken{ - Token: tok.RefreshToken, - }, - IDToken: &oidctypes.IDToken{ - Token: idTok, - Expiry: metav1.NewTime(validated.Expiry), - }, - }, nil -} - func (h *handlerState) serve(listener net.Listener) func() { mux := http.NewServeMux() mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 5bff0142..280dfd0a 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -18,11 +18,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" - "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" @@ -57,19 +57,9 @@ func TestLogin(t *testing.T) { require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) testToken := oidctypes.Token{ - AccessToken: &oidctypes.AccessToken{ - Token: "test-access-token", - Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), - }, - RefreshToken: &oidctypes.RefreshToken{ - Token: "test-refresh-token", - }, - IDToken: &oidctypes.IDToken{ - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above): - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775 - Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA", - Expiry: metav1.NewTime(time1.Add(2 * time.Minute)), - }, + AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, + IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))}, } // Start a test server that returns 500 errors @@ -78,7 +68,7 @@ func TestLogin(t *testing.T) { })) t.Cleanup(errorServer.Close) - // Start a test server that returns a real keyset and answers refresh requests. + // Start a test server that returns a real discovery document and answers refresh requests. providerMux := http.NewServeMux() successServer := httptest.NewServer(providerMux) t.Cleanup(successServer.Close) @@ -248,6 +238,14 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(testToken, nil, nil) + return mock + } + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", @@ -268,12 +266,6 @@ func TestLogin(t *testing.T) { require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) }) h.cache = cache - - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } return nil } }, @@ -285,6 +277,14 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error")) + return mock + } + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", @@ -298,16 +298,10 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - return nil } }, - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + wantErr: "some validation error", }, { name: "session cache hit but refresh fails", @@ -328,12 +322,6 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - h.listenAddr = "invalid-listen-address" return nil @@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { name string method string query string - returnIDTok string + opt func(t *testing.T) Option wantErr string wantHTTPStatus int }{ @@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) { { name: "invalid code", query: "state=test-state&code=invalid", - wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", + wantErr: "could not complete code exchange: some exchange error", wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "missing ID token", - query: "state=test-state&code=valid", - returnIDTok: "", - wantErr: "received response missing ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid ID token", - query: "state=test-state&code=valid", - returnIDTok: "invalid-jwt", - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid access token hash", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg", - - wantErr: "received invalid ID token: access token hash does not match value in ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid nonce", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ", - - wantHTTPStatus: http.StatusBadRequest, - wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) + return mock + } + return nil + } + }, }, { name: "valid", query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) + return mock + } + return nil + } + }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.NoError(t, r.ParseForm()) - require.Equal(t, "test-client-id", r.Form.Get("client_id")) - require.Equal(t, "test-pkce", r.Form.Get("code_verifier")) - require.Equal(t, "authorization_code", r.Form.Get("grant_type")) - require.NotEmpty(t, r.Form.Get("code")) - if r.Form.Get("code") != "valid" { - http.Error(w, "invalid authorization code", http.StatusForbidden) - return - } - var response struct { - oauth2.Token - IDToken string `json:"id_token,omitempty"` - } - response.AccessToken = "test-access-token" - response.Expiry = time.Now().Add(time.Hour) - response.IDToken = tt.returnIDTok - w.Header().Set("content-type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(&response)) - })) - t.Cleanup(tokenServer.Close) - h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), pkce: pkce.Code("test-pkce"), nonce: nonce.Nonce("test-nonce"), - oauth2Config: &oauth2.Config{ - ClientID: "test-client-id", - RedirectURL: "http://localhost:12345/callback", - Endpoint: oauth2.Endpoint{ - TokenURL: tokenServer.URL, - AuthStyle: oauth2.AuthStyleInParams, - }, - }, - idTokenVerifier: mockVerifier(), + } + if tt.opt != nil { + require.NoError(t, tt.opt(t)(h)) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) { } require.NoError(t, result.err) require.NotNil(t, result.token) - require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) + require.Equal(t, result.token.IDToken.Token, "test-id-token") } }) } } -// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. -func mockVerifier() *oidc.IDTokenVerifier { - mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) - mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). - AnyTimes(). - DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) { - jws, err := jose.ParseSigned(jwt) - if err != nil { - return nil, err - } - return jws.UnsafePayloadWithoutVerification(), nil - }) - - return oidc.NewVerifier("", mockKeySet, &oidc.Config{ - SkipIssuerCheck: true, - SkipExpiryCheck: true, - SkipClientIDCheck: true, - }) +func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl) } -type mockDiscovery struct{ provider *oidc.Provider } +// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token. +type hasAccessTokenMatcher struct{ expected string } -func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } +func (m hasAccessTokenMatcher) Matches(arg interface{}) bool { + return arg.(*oauth2.Token).AccessToken == m.expected +} -func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } +func (m hasAccessTokenMatcher) Got(got interface{}) string { + return got.(*oauth2.Token).AccessToken +} + +func (m hasAccessTokenMatcher) String() string { + return m.expected +} + +func HasAccessToken(expected string) gomock.Matcher { + return hasAccessTokenMatcher{expected: expected} +}