From 79ca1d7fb053fc6c2cfc1d96233b7964312dca38 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 13 Oct 2021 12:31:20 -0700 Subject: [PATCH] Perform an upstream refresh during downstream refresh for OIDC upstreams - If the upstream refresh fails, then fail the downstream refresh - If the upstream refresh returns an ID token, then validate it (we use its claims in the future, but not in this commit) - If the upstream refresh returns a new refresh token, then save it into the user's session in storage - Pass the provider cache into the token handler so it can use the cached providers to perform upstream refreshes - Handle unexpected errors in the token handler where the user's session does not contain the expected data. These should not be possible in practice unless someone is manually editing the storage, but handle them anyway just to be safe. - Refactor to share the refresh code between the CLI and the token endpoint by moving it into the UpstreamOIDCIdentityProviderI interface, since the token endpoint needed it to be part of that interface anyway --- .../mockupstreamoidcidentityprovider.go | 15 + internal/oidc/oidc.go | 4 +- .../provider/dynamic_upstream_idp_provider.go | 8 + internal/oidc/provider/manager/manager.go | 1 + internal/oidc/token/token_handler.go | 127 +++ internal/oidc/token/token_handler_test.go | 775 +++++++++++++++--- .../testutil/oidctestutil/oidctestutil.go | 180 +++- internal/upstreamoidc/upstreamoidc.go | 14 +- internal/upstreamoidc/upstreamoidc_test.go | 166 ++++ pkg/oidcclient/login.go | 6 +- pkg/oidcclient/login_test.go | 25 +- 11 files changed, 1195 insertions(+), 126 deletions(-) diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index dfbe785b..fd8b7dd9 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -200,6 +200,21 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PasswordCredentialsGran return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCredentialsGrantAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PasswordCredentialsGrantAndValidateTokens), arg0, arg1, arg2) } +// PerformRefresh mocks base method. +func (m *MockUpstreamOIDCIdentityProviderI) PerformRefresh(arg0 context.Context, arg1 string) (*oauth2.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PerformRefresh", arg0, arg1) + ret0, _ := ret[0].(*oauth2.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PerformRefresh indicates an expected call of PerformRefresh. +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1) +} + // ValidateToken mocks base method. func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) { m.ctrl.T.Helper() diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index b70e0309..6c3c1918 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -210,8 +210,6 @@ func FositeOauth2Helper( // The default is to support all prompt values from the spec. // See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // We'll make a best effort to support these by passing the value of this prompt param to the upstream IDP - // and rely on its implementation of this param. AllowedPromptValues: nil, // Use the fosite default to make it more likely that off the shelf OIDC clients can work with the supervisor. @@ -232,7 +230,7 @@ func FositeOauth2Helper( compose.OpenIDConnectExplicitFactory, compose.OpenIDConnectRefreshFactory, compose.OAuth2PKCEFactory, - TokenExchangeFactory, + TokenExchangeFactory, // handle the "urn:ietf:params:oauth:grant-type:token-exchange" grant type ) provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template() return provider diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 6d0af8ec..9ece5f4d 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -63,6 +63,14 @@ type UpstreamOIDCIdentityProviderI interface { redirectURI string, ) (*oidctypes.Token, error) + // PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not + // return a new ID or refresh token in the response. If it returns an ID token, then use ValidateRefresh to + // validate the ID token. + PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) + + // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response + // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated + // tokens, or an error. ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) } diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index ea1d2d62..708d4855 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -130,6 +130,7 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs ) m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler( + m.upstreamIDPs, oauthHelperWithKubeStorage, ) diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 72222367..724ee0aa 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -5,17 +5,36 @@ package token import ( + "context" "net/http" "github.com/ory/fosite" + "github.com/ory/x/errorsx" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" ) +var ( + errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{ + ErrorField: "error", + DescriptionField: "There was an internal server error.", + HintField: "Required upstream data not found in session.", + CodeField: http.StatusInternalServerError, + } + + errUpstreamRefreshError = &fosite.RFC6749Error{ + ErrorField: "error", + DescriptionField: "Error during upstream refresh.", + CodeField: http.StatusUnauthorized, + } +) + func NewHandler( + idpLister oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { @@ -27,6 +46,20 @@ func NewHandler( return nil } + // Check if we are performing a refresh grant. + if accessRequest.GetGrantTypes().ExactOne("refresh_token") { + // The above call to NewAccessRequest has loaded the session from storage into the accessRequest variable. + // The session, requested scopes, and requested audience from the original authorize request was retrieved + // from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may + // have already been granted on the accessRequest. + err = upstreamRefresh(r.Context(), accessRequest, idpLister) + if err != nil { + plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...) + oauthHelper.WriteAccessError(w, accessRequest, err) + return nil + } + } + accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest) if err != nil { plog.Info("token response error", oidc.FositeErrorForLog(err)...) @@ -39,3 +72,97 @@ func NewHandler( return nil }) } + +func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { + session := accessRequest.GetSession().(*psession.PinnipedSession) + customSessionData := session.Custom + if customSessionData == nil { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + providerName := customSessionData.ProviderName + providerUID := customSessionData.ProviderUID + if providerUID == "" || providerName == "" { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + switch customSessionData.ProviderType { + case psession.ProviderTypeOIDC: + return upstreamOIDCRefresh(ctx, customSessionData, providerCache) + case psession.ProviderTypeLDAP: + // upstream refresh not yet implemented for LDAP, so do nothing + case psession.ProviderTypeActiveDirectory: + // upstream refresh not yet implemented for AD, so do nothing + default: + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + return nil +} + +func upstreamOIDCRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { + if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + p, err := findOIDCProviderByNameAndValidateUID(s, providerCache) + if err != nil { + return err + } + + plog.Debug("attempting upstream refresh request", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + + refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh failed using provider %q of type %q.", + s.ProviderName, s.ProviderType).WithWrap(err)) + } + + // Upstream refresh may or may not return a new ID token. From the spec: + // "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token." + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + _, hasIDTok := refreshedTokens.Extra("id_token").(string) + if hasIDTok { + // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at + // least some providers do not include one, so we skip the nonce validation here (but not other validations). + _, err = p.ValidateToken(ctx, refreshedTokens, "") + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh returned an invalid ID token using provider %q of type %q.", + s.ProviderName, s.ProviderType).WithWrap(err)) + } + } else { + plog.Debug("upstream refresh request did not return a new ID token", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + } + + // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in + // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding + // overwriting the old one. + if refreshedTokens.RefreshToken != "" { + plog.Debug("upstream refresh request did not return a new refresh token", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken + } + + return nil +} + +func findOIDCProviderByNameAndValidateUID( + s *psession.CustomSessionData, + providerCache oidc.UpstreamIdentityProvidersLister, +) (provider.UpstreamOIDCIdentityProviderI, error) { + for _, p := range providerCache.GetOIDCIdentityProviders() { + if p.GetName() == s.ProviderName { + if p.GetResourceUID() != s.ProviderUID { + return nil, errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Provider %q of type %q from upstream session data has changed its resource UID since authentication.", + s.ProviderName, s.ProviderType)) + } + return p, nil + } + } + return nil, errorsx.WithStack(errUpstreamRefreshError. + WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) +} diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index 262f5756..968944d7 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -23,12 +23,13 @@ import ( "time" "github.com/ory/fosite" - "github.com/ory/fosite/handler/oauth2" + fositeoauth2 "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/pkce" "github.com/ory/fosite/token/jwt" "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" josejwt "gopkg.in/square/go-jose.v2/jwt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -44,8 +45,10 @@ import ( "go.pinniped.dev/internal/fositestorage/refreshtoken" "go.pinniped.dev/internal/fositestoragei" "go.pinniped.dev/internal/here" + "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/jwks" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -179,6 +182,13 @@ var ( } `) + pinnipedUpstreamSessionDataNotFoundErrorBody = here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `) + happyAuthRequest = &http.Request{ Form: url.Values{ "response_type": {"code"}, @@ -206,12 +216,25 @@ var ( } ) +type expectedUpstreamRefresh struct { + performedByUpstreamName string + args *oidctestutil.PerformRefreshArgs +} + +type expectedUpstreamValidateTokens struct { + performedByUpstreamName string + args *oidctestutil.ValidateTokenArgs +} + type tokenEndpointResponseExpectedValues struct { - wantStatus int - wantSuccessBodyFields []string - wantErrorResponseBody string - wantRequestedScopes []string - wantGrantedScopes []string + wantStatus int + wantSuccessBodyFields []string + wantErrorResponseBody string + wantRequestedScopes []string + wantGrantedScopes []string + wantUpstreamOIDCRefreshCall *expectedUpstreamRefresh + wantUpstreamOIDCValidateTokenCall *expectedUpstreamValidateTokens + wantCustomSessionDataStored *psession.CustomSessionData } type authcodeExchangeInputs struct { @@ -222,13 +245,9 @@ type authcodeExchangeInputs struct { s fositestoragei.AllFositeStorage, authCode string, ) - makeOathHelper func( - t *testing.T, - authRequest *http.Request, - store fositestoragei.AllFositeStorage, - ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) - - want tokenEndpointResponseExpectedValues + makeOathHelper OauthHelperFactoryFunc + customSessionData *psession.CustomSessionData + want tokenEndpointResponseExpectedValues } func TestTokenEndpoint(t *testing.T) { @@ -520,7 +539,8 @@ func TestTokenEndpoint(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + exchangeAuthcodeForTokens(t, test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) }) } } @@ -549,7 +569,9 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) { t.Parallel() // First call - should be successful. - subject, rsp, authCode, _, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + subject, rsp, authCode, _, secrets, oauthStore := exchangeAuthcodeForTokens(t, + test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) var parsedResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody)) @@ -573,9 +595,11 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) { requireInvalidAccessTokenStorage(t, parsedResponseBody, oauthStore) // This was previously invalidated by the first request, so it remains invalidated requireInvalidPKCEStorage(t, authCode, oauthStore) - // Fosite never cleans up OpenID Connect session storage, so it is still there + // Fosite never cleans up OpenID Connect session storage, so it is still there. + // Note that customSessionData is only relevant to refresh grant, so we leave it as nil for this + // authcode exchange test, even though in practice it would actually be in the session. requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, - test.authcodeExchange.want.wantRequestedScopes, test.authcodeExchange.want.wantGrantedScopes) + test.authcodeExchange.want.wantRequestedScopes, test.authcodeExchange.want.wantGrantedScopes, nil) // Check that the access token and refresh token storage were both deleted, and the number of other storage objects did not change. testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) @@ -602,6 +626,7 @@ func TestTokenExchange(t *testing.T) { }, want: successfulAuthCodeExchange, } + tests := []struct { name string @@ -742,7 +767,9 @@ func TestTokenExchange(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, + test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) var parsedAuthcodeExchangeResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) @@ -845,80 +872,219 @@ type refreshRequestInputs struct { } func TestRefreshGrant(t *testing.T) { + const ( + oidcUpstreamName = "some-oidc-idp" + oidcUpstreamResourceUID = "oidc-resource-uid" + oidcUpstreamType = "oidc" + oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token" + oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token" + oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token" + ) + + // The below values are funcs so every test can have its own copy of the objects, to avoid data races + // in these parallel tests. + + upstreamOIDCIdentityProviderBuilder := func() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { + return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(oidcUpstreamName). + WithResourceUID(oidcUpstreamResourceUID) + } + + initialUpstreamOIDCCustomSessionData := func() *psession.CustomSessionData { + return &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: oidcUpstreamInitialRefreshToken, + }, + } + } + + upstreamOIDCCustomSessionDataWithNewRefreshToken := func(newRefreshToken string) *psession.CustomSessionData { + sessionData := initialUpstreamOIDCCustomSessionData() + sessionData.OIDC.UpstreamRefreshToken = newRefreshToken + return sessionData + } + + happyUpstreamRefreshCall := func() *expectedUpstreamRefresh { + return &expectedUpstreamRefresh{ + performedByUpstreamName: oidcUpstreamName, + args: &oidctestutil.PerformRefreshArgs{ + Ctx: nil, // this will be filled in with the actual request context by the test below + RefreshToken: oidcUpstreamInitialRefreshToken, + }, + } + } + + happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token) *expectedUpstreamValidateTokens { + return &expectedUpstreamValidateTokens{ + performedByUpstreamName: oidcUpstreamName, + args: &oidctestutil.ValidateTokenArgs{ + Ctx: nil, // this will be filled in with the actual request context by the test below + Tok: expectedTokens, + ExpectedIDTokenNonce: "", // always expect empty string + }, + } + } + + happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess := func(wantCustomSessionDataStored *psession.CustomSessionData) tokenEndpointResponseExpectedValues { + want := tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access"}, + wantGrantedScopes: []string{"openid", "offline_access"}, + wantCustomSessionDataStored: wantCustomSessionDataStored, + } + return want + } + + happyRefreshTokenResponseForOpenIDAndOfflineAccess := func(wantCustomSessionDataStored *psession.CustomSessionData, expectToValidateToken *oauth2.Token) tokenEndpointResponseExpectedValues { + // Should always have some custom session data stored. The other expectations happens to be the + // same as the same values as the authcode exchange case. + want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) + // Should always try to perform an upstream refresh. + want.wantUpstreamOIDCRefreshCall = happyUpstreamRefreshCall() + // Should only try to ValidateToken when there was an id token returned by the upstream refresh. + if expectToValidateToken != nil { + want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) + } + return want + } + + refreshedUpstreamTokensWithRefreshTokenWithoutIDToken := func() *oauth2.Token { + return &oauth2.Token{ + AccessToken: "fake-refreshed-access-token", + TokenType: "Bearer", + RefreshToken: oidcUpstreamRefreshedRefreshToken, + Expiry: time.Date(2050, 1, 1, 1, 1, 1, 1, time.UTC), + } + } + + refreshedUpstreamTokensWithIDAndRefreshTokens := func() *oauth2.Token { + return refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(). + WithExtra(map[string]interface{}{"id_token": oidcUpstreamRefreshedIDToken}) + } + + refreshedUpstreamTokensWithIDTokenWithoutRefreshToken := func() *oauth2.Token { + tokens := refreshedUpstreamTokensWithIDAndRefreshTokens() + tokens.RefreshToken = "" // remove the refresh token + return tokens + } + tests := []struct { name string + idps *oidctestutil.UpstreamIDPListerBuilder authcodeExchange authcodeExchangeInputs refreshRequest refreshRequestInputs }{ { - name: "happy path refresh grant with ID token", + name: "happy path refresh grant with openid scope granted (id token returned)", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { - name: "happy path refresh grant without ID token", + name: "happy path refresh grant without openid scope granted (no id token returned)", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, - }}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + }, + }, + }, + { + name: "happy path refresh grant when the upstream refresh does not return a new ID token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + nil, // expect ValidateToken is *not* called + ), + }, + }, + { + name: "happy path refresh grant when the upstream refresh does not return a new refresh token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDTokenWithoutRefreshToken()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + initialUpstreamOIDCCustomSessionData(), // still has the initial refresh token stored + refreshedUpstreamTokensWithIDTokenWithoutRefreshToken(), + ), + }, }, { name: "when the refresh request adds a new scope to the list of requested scopes then it is ignored", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("openid some-other-scope-not-from-auth-request").ReadCloser() }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { name: "when the refresh request removes a scope which was originally granted from the list of requested scopes then it is granted anyway", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access pinniped:request-audience") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -926,43 +1092,47 @@ func TestRefreshGrant(t *testing.T) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("openid").ReadCloser() // do not ask for "pinniped:request-audience" again }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - }}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + }, + }, }, { name: "when the refresh request does not include a scope param then it gets all the same scopes as the original authorization request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("").ReadCloser() }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { name: "when a bad refresh token is sent in the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -972,17 +1142,21 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusBadRequest, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody, - }}, + }, + }, }, { name: "when the access token is sent as if it were a refresh token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -992,17 +1166,21 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusBadRequest, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody, - }}, + }, + }, }, { name: "when the wrong client ID is included in the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1012,7 +1190,301 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusUnauthorized, wantErrorResponseBody: fositeInvalidClientErrorBody, - }}, + }, + }, + }, + { + name: "when there is no custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: nil, // this should not happen in practice + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(nil), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider name in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: "", // this should not happen in practice + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: "", // this should not happen in practice + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider UID in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: "", // this should not happen in practice + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: "", // this should not happen in practice + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider type in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is an illegal provider type in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "not-an-allowed-provider-type", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "not-an-allowed-provider-type", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no OIDC-specific data in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: nil, // this should not happen in practice + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: nil, // this should not happen in practice + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no OIDC refresh token in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: "", // this should not happen in practice + }, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: "", // this should not happen in practice + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when the provider in the session storage is not found due to its name during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: "this-name-will-not-be-found", // this could happen if the OIDCIdentityProvider was deleted since original login + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: "this-name-will-not-be-found", // this could happen if the OIDCIdentityProvider was deleted since original login + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'this-name-will-not-be-found' of type 'oidc' from upstream session data was not found." + } + `), + }, + }, + }, + { + name: "when the provider in the session storage is found but has the wrong resource UID during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: "this is the wrong uid", // this could happen if the OIDCIdentityProvider was deleted and recreated at the same name since original login + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: "this is the wrong uid", // this could happen if the OIDCIdentityProvider was deleted and recreated at the same name since original login + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'some-oidc-idp' of type 'oidc' from upstream session data has changed its resource UID since authentication." + } + `), + }, + }, + }, + { + name: "when the upstream refresh fails during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). + WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh failed using provider 'some-oidc-idp' of type 'oidc'." + } + `), + }, + }, + }, + { + name: "when the upstream refresh returns an invalid ID token during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). + WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). + // This is the current format of the errors returned by the production code version of ValidateToken, see ValidateToken in upstreamoidc.go + WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). + Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh returned an invalid ID token using provider 'some-oidc-idp' of type 'oidc'." + } + `), + }, + }, }, } for _, test := range tests { @@ -1021,10 +1493,15 @@ func TestRefreshGrant(t *testing.T) { t.Parallel() // First exchange the authcode for tokens, including a refresh token. - subject, rsp, authCode, jwtSigningKey, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange) + subject, rsp, authCode, jwtSigningKey, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange, test.idps.Build()) var parsedAuthcodeExchangeResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) + // Performing an authcode exchange should not have caused any upstream refresh, which should only + // happen during a downstream refresh. + test.idps.RequireExactlyZeroCallsToPerformRefresh(t) + test.idps.RequireExactlyZeroCallsToValidateToken(t) + // Wait one second before performing the refresh so we can see that the refreshed ID token has new issued // at and expires at dates which are newer than the old tokens. // If this gets too annoying in terms of making our test suite slower then we can remove it and adjust @@ -1033,8 +1510,10 @@ func TestRefreshGrant(t *testing.T) { // Send the refresh token back and preform a refresh. firstRefreshToken := parsedAuthcodeExchangeResponseBody["refresh_token"].(string) + require.NotEmpty(t, firstRefreshToken) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") req := httptest.NewRequest("POST", "/path/shouldn't/matter", - happyRefreshRequestBody(firstRefreshToken).ReadCloser()) + happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") if test.refreshRequest.modifyTokenRequest != nil { test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string)) @@ -1045,11 +1524,45 @@ func TestRefreshGrant(t *testing.T) { t.Logf("second response: %#v", refreshResponse) t.Logf("second response body: %q", refreshResponse.Body.String()) + // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. + if test.refreshRequest.want.wantUpstreamOIDCRefreshCall != nil { + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToPerformRefresh(t, + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.performedByUpstreamName, + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args, + ) + } else { + test.idps.RequireExactlyZeroCallsToPerformRefresh(t) + } + + // Test that we did or did not make a call to the upstream OIDC provider interface to validate the + // new ID token that was returned by the upstream refresh. + if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil { + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToValidateToken(t, + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName, + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args, + ) + } else { + test.idps.RequireExactlyZeroCallsToValidateToken(t) + } + // The bug in fosite that prevents at_hash from appearing in the initial ID token does not impact the refreshed ID token wantAtHashClaimInIDToken := true // Refreshed ID tokens do not include the nonce from the original auth request wantNonceValueInIDToken := false - requireTokenEndpointBehavior(t, test.refreshRequest.want, wantAtHashClaimInIDToken, wantNonceValueInIDToken, refreshResponse, authCode, oauthStore, jwtSigningKey, secrets) + + requireTokenEndpointBehavior(t, + test.refreshRequest.want, + test.authcodeExchange.customSessionData, + wantAtHashClaimInIDToken, + wantNonceValueInIDToken, + refreshResponse, + authCode, + oauthStore, + jwtSigningKey, + secrets, + ) if test.refreshRequest.want.wantStatus == http.StatusOK { wantIDToken := contains(test.refreshRequest.want.wantSuccessBodyFields, "id_token") @@ -1109,7 +1622,7 @@ func requireClaimsAreEqual(t *testing.T, claimName string, claimsOfTokenA map[st require.Equal(t, claimsOfTokenA[claimName], claimsOfTokenB[claimName]) } -func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( +func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps provider.DynamicUpstreamIDPProvider) ( subject http.Handler, rsp *httptest.ResponseRecorder, authCode string, @@ -1129,15 +1642,17 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( oauthStore = oidc.NewKubeStorage(secrets, oidc.DefaultOIDCTimeoutsConfiguration()) if test.makeOathHelper != nil { - oauthHelper, authCode, jwtSigningKey = test.makeOathHelper(t, authRequest, oauthStore) + oauthHelper, authCode, jwtSigningKey = test.makeOathHelper(t, authRequest, oauthStore, test.customSessionData) } else { - oauthHelper, authCode, jwtSigningKey = makeHappyOauthHelper(t, authRequest, oauthStore) + // Note that makeHappyOauthHelper() calls simulateAuthEndpointHavingAlreadyRun() to preload the session storage. + oauthHelper, authCode, jwtSigningKey = makeHappyOauthHelper(t, authRequest, oauthStore, test.customSessionData) } if test.modifyStorage != nil { test.modifyStorage(t, oauthStore, authCode) } - subject = NewHandler(oauthHelper) + + subject = NewHandler(idps, oauthHelper) authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid") expectedNumberOfIDSessionsStored := 0 @@ -1163,7 +1678,18 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( wantAtHashClaimInIDToken := false // due to a bug in fosite, the at_hash claim is not filled in during authcode exchange wantNonceValueInIDToken := true // ID tokens returned by the authcode exchange must include the nonce from the auth request (unliked refreshed ID tokens) - requireTokenEndpointBehavior(t, test.want, wantAtHashClaimInIDToken, wantNonceValueInIDToken, rsp, authCode, oauthStore, jwtSigningKey, secrets) + + requireTokenEndpointBehavior(t, + test.want, + test.customSessionData, + wantAtHashClaimInIDToken, + wantNonceValueInIDToken, + rsp, + authCode, + oauthStore, + jwtSigningKey, + secrets, + ) return subject, rsp, authCode, jwtSigningKey, secrets, oauthStore } @@ -1171,6 +1697,7 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( func requireTokenEndpointBehavior( t *testing.T, test tokenEndpointResponseExpectedValues, + oldCustomSessionData *psession.CustomSessionData, wantAtHashClaimInIDToken bool, wantNonceValueInIDToken bool, tokenEndpointResponse *httptest.ResponseRecorder, @@ -1193,9 +1720,10 @@ func requireTokenEndpointBehavior( wantRefreshToken := contains(test.wantSuccessBodyFields, "refresh_token") requireInvalidAuthCodeStorage(t, authCode, oauthStore, secrets) - requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, secrets) + requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, test.wantCustomSessionDataStored, secrets) requireInvalidPKCEStorage(t, authCode, oauthStore) - requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes) + // Performing a refresh does not update the OIDC storage, so after a refresh it should still have the old custom session data from the initial login. + requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, oldCustomSessionData) expectedNumberOfRefreshTokenSessionsStored := 0 if wantRefreshToken { @@ -1207,7 +1735,7 @@ func requireTokenEndpointBehavior( requireValidIDToken(t, parsedResponseBody, jwtSigningKey, wantAtHashClaimInIDToken, wantNonceValueInIDToken, parsedResponseBody["access_token"].(string)) } if wantRefreshToken { - requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, secrets) + requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, test.wantCustomSessionDataStored, secrets) } testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) @@ -1304,16 +1832,24 @@ func getFositeDataSignature(t *testing.T, data string) string { return split[1] } +type OauthHelperFactoryFunc func( + t *testing.T, + authRequest *http.Request, + store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, +) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) + func makeHappyOauthHelper( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwtSigningKey, jwkProvider := generateJWTSigningKeyAndJWKSProvider(t, goodIssuer) oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, jwkProvider, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), jwtSigningKey } @@ -1334,12 +1870,13 @@ func makeOauthHelperWithJWTKeyThatWorksOnlyOnce( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwtSigningKey, jwkProvider := generateJWTSigningKeyAndJWKSProvider(t, goodIssuer) oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, &singleUseJWKProvider{DynamicJWKSProvider: jwkProvider}, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), jwtSigningKey } @@ -1347,17 +1884,23 @@ func makeOauthHelperWithNilPrivateJWTSigningKey( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwkProvider := jwks.NewDynamicJWKSProvider() // empty provider which contains no signing key for this issuer oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, jwkProvider, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), nil } // Simulate the auth endpoint running so Fosite code will fill the store with realistic values. -func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Request, oauthHelper fosite.OAuth2Provider) fosite.AuthorizeResponder { +func simulateAuthEndpointHavingAlreadyRun( + t *testing.T, + authRequest *http.Request, + oauthHelper fosite.OAuth2Provider, + initialCustomSessionData *psession.CustomSessionData, +) fosite.AuthorizeResponder { // We only set the fields in the session that Fosite wants us to set. ctx := context.Background() session := &psession.PinnipedSession{ @@ -1374,11 +1917,7 @@ func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Reques Subject: "", // not used, note that callback_handler.go does not set this Username: "", // not used, note that callback_handler.go does not set this }, - Custom: &psession.CustomSessionData{ - OIDC: &psession.OIDCSessionData{ - UpstreamRefreshToken: "starting-fake-refresh-token", - }, - }, + Custom: initialCustomSessionData, } authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest) require.NoError(t, err) @@ -1416,7 +1955,7 @@ func generateJWTSigningKeyAndJWKSProvider(t *testing.T, issuer string) (*ecdsa.P func requireInvalidAuthCodeStorage( t *testing.T, code string, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, secrets v1.SecretInterface, ) { t.Helper() @@ -1431,9 +1970,10 @@ func requireInvalidAuthCodeStorage( func requireValidRefreshTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, secrets v1.SecretInterface, ) { t.Helper() @@ -1455,6 +1995,7 @@ func requireValidRefreshTokenStorage( wantRequestedScopes, wantGrantedScopes, true, + wantCustomSessionData, ) requireGarbageCollectTimeInDelta(t, refreshTokenString, "refresh-token", secrets, time.Now().Add(9*time.Hour).Add(2*time.Minute), 1*time.Minute) @@ -1463,9 +2004,10 @@ func requireValidRefreshTokenStorage( func requireValidAccessTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, secrets v1.SecretInterface, ) { t.Helper() @@ -1506,6 +2048,7 @@ func requireValidAccessTokenStorage( wantRequestedScopes, wantGrantedScopes, true, + wantCustomSessionData, ) requireGarbageCollectTimeInDelta(t, accessTokenString, "access-token", secrets, time.Now().Add(9*time.Hour).Add(2*time.Minute), 1*time.Minute) @@ -1514,7 +2057,7 @@ func requireValidAccessTokenStorage( func requireInvalidAccessTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, ) { t.Helper() @@ -1547,6 +2090,7 @@ func requireValidOIDCStorage( storage openid.OpenIDConnectRequestStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, ) { t.Helper() @@ -1569,6 +2113,7 @@ func requireValidOIDCStorage( wantRequestedScopes, wantGrantedScopes, false, + wantCustomSessionData, ) } else { _, err := storage.GetOpenIDConnectSession(context.Background(), code, nil) @@ -1583,6 +2128,7 @@ func requireValidStoredRequest( wantRequestedScopes []string, wantGrantedScopes []string, wantAccessTokenExpiresAt bool, + wantCustomSessionData *psession.CustomSessionData, ) { t.Helper() @@ -1669,6 +2215,9 @@ func requireValidStoredRequest( // We don't use these, so they should be empty. require.Empty(t, session.Fosite.Username) require.Empty(t, session.Fosite.Subject) + + // The custom session data was stored as expected. + require.Equal(t, wantCustomSessionData, session.Custom) } func requireGarbageCollectTimeInDelta(t *testing.T, tokenString string, typeLabel string, secrets v1.SecretInterface, wantExpirationTime time.Time, deltaTime time.Duration) { diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 3f8c3f57..90361ddd 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -58,6 +58,21 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct { Password string } +// PerformRefreshArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). +type PerformRefreshArgs struct { + Ctx context.Context + RefreshToken string +} + +// ValidateTokenArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc(). +type ValidateTokenArgs struct { + Ctx context.Context + Tok *oauth2.Token + ExpectedIDTokenNonce nonce.Nonce +} + type TestUpstreamLDAPIdentityProvider struct { Name string ResourceUID types.UID @@ -107,10 +122,18 @@ type TestUpstreamOIDCIdentityProvider struct { password string, ) (*oidctypes.Token, error) + PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) + + ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) + exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs + performRefreshCallCount int + performRefreshArgs []*PerformRefreshArgs + validateTokenCallCount int + validateTokenArgs []*ValidateTokenArgs } var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} @@ -193,8 +216,51 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs return u.exchangeAuthcodeAndValidateTokensArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) { - panic("implement me") +func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + u.performRefreshCallCount++ + u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ + Ctx: ctx, + RefreshToken: refreshToken, + }) + return u.PerformRefreshFunc(ctx, refreshToken) +} + +func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int { + return u.performRefreshCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + return u.performRefreshArgs[call] +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.validateTokenArgs == nil { + u.validateTokenArgs = make([]*ValidateTokenArgs, 0) + } + u.validateTokenCallCount++ + u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{ + Ctx: ctx, + Tok: tok, + ExpectedIDTokenNonce: expectedIDTokenNonce, + }) + return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce) +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int { + return u.validateTokenCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs { + if u.validateTokenArgs == nil { + u.validateTokenArgs = make([]*ValidateTokenArgs, 0) + } + return u.validateTokenArgs[call] } type UpstreamIDPListerBuilder struct { @@ -316,6 +382,80 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV ) } +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *PerformRefreshArgs, +) { + t.Helper() + var actualArgs *PerformRefreshArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.performRefreshArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to PerformRefresh() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "PerformRefresh() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.performRefreshCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to PerformRefresh()", + ) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *ValidateTokenArgs, +) { + t.Helper() + var actualArgs *ValidateTokenArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.validateTokenArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to ValidateToken() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "ValidateToken() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to ValidateToken()", + ) +} + func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder { return &UpstreamIDPListerBuilder{} } @@ -329,11 +469,15 @@ type TestUpstreamOIDCIdentityProviderBuilder struct { refreshToken *oidctypes.RefreshToken usernameClaim string groupsClaim string + refreshedTokens *oauth2.Token + validatedTokens *oidctypes.Token authorizationURL url.URL additionalAuthcodeParams map[string]string allowPasswordGrant bool authcodeExchangeErr error passwordGrantErr error + performRefreshErr error + validateTokenErr error } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { @@ -429,6 +573,26 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err err return u } +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshedTokens(tokens *oauth2.Token) *TestUpstreamOIDCIdentityProviderBuilder { + u.refreshedTokens = tokens + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.performRefreshErr = err + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { + u.validatedTokens = tokens + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.validateTokenErr = err + return u +} + func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider { return &TestUpstreamOIDCIdentityProvider{ Name: u.name, @@ -452,6 +616,18 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent } return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken}, nil }, + PerformRefreshFunc: func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + if u.performRefreshErr != nil { + return nil, u.performRefreshErr + } + return u.refreshedTokens, nil + }, + ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.validateTokenErr != nil { + return nil, u.validateTokenErr + } + return u.validatedTokens, nil + }, } } diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 17e70915..34c27be2 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -47,6 +47,8 @@ type ProviderConfig struct { } } +var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) + func (p *ProviderConfig) GetResourceUID() types.UID { return p.ResourceUID } @@ -120,6 +122,14 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, return p.ValidateToken(ctx, tok, expectedIDTokenNonce) } +func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Create a TokenSource without an access token, so it thinks that a refresh is immediately required. + // Then ask it for the tokens to cause it to perform the refresh and return the results. + return p.Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}).Token() +} + +// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response, +// if the provider offers the userinfo endpoint. func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { @@ -146,7 +156,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e } maybeLogClaims("claims from ID token", p.Name, validatedClaims) - if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil { + if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims); err != nil { return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) } @@ -167,7 +177,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e }, nil } -func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error { +func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error { idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string) if len(idTokenSubject) == 0 { return nil // defer to existing ID token validation diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 2ffd40ca..342683fc 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -23,6 +23,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" ) @@ -288,6 +289,171 @@ func TestProviderConfig(t *testing.T) { } }) + t.Run("PerformRefresh", func(t *testing.T) { + tests := []struct { + name string + returnIDTok string + returnAccessTok string + returnRefreshTok string + returnTokType string + returnExpiresIn string + tokenStatusCode int + + wantErr string + wantToken *oauth2.Token + wantTokenExtras map[string]interface{} + }{ + { + name: "success when the server returns all tokens in the refresh result", + returnIDTok: "test-id-token", + returnAccessTok: "test-access-token", + returnRefreshTok: "test-refresh-token", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the ID token only appears in the extras map + "id_token": "test-id-token", + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "success when the server does not return a new refresh token in the refresh result", + returnIDTok: "test-id-token", + returnAccessTok: "test-access-token", + returnRefreshTok: "", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + // the library sets the original refresh token into the result, even though the server did not return that + RefreshToken: "test-initial-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the ID token only appears in the extras map + "id_token": "test-id-token", + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "success when the server does not return a new ID token in the refresh result", + returnIDTok: "", + returnAccessTok: "test-access-token", + returnRefreshTok: "test-refresh-token", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "server returns an error on token refresh", + tokenStatusCode: http.StatusForbidden, + wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: fake error\n", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, r.ParseForm()) + require.Equal(t, 4, len(r.Form)) + require.Equal(t, "test-client-id", r.Form.Get("client_id")) + require.Equal(t, "test-client-secret", r.Form.Get("client_secret")) + require.Equal(t, "refresh_token", r.Form.Get("grant_type")) + require.Equal(t, "test-initial-refresh-token", r.Form.Get("refresh_token")) + if tt.tokenStatusCode != http.StatusOK { + http.Error(w, "fake error", tt.tokenStatusCode) + return + } + var response struct { + oauth2.Token + IDToken string `json:"id_token,omitempty"` + ExpiresIn string `json:"expires_in,omitempty"` + } + response.IDToken = tt.returnIDTok + response.AccessToken = tt.returnAccessTok + response.RefreshToken = tt.returnRefreshTok + response.TokenType = tt.returnTokType + response.ExpiresIn = tt.returnExpiresIn + w.Header().Set("content-type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(&response)) + })) + t.Cleanup(tokenServer.Close) + + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "https://example.com", + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + } + + tok, err := p.PerformRefresh( + context.Background(), + "test-initial-refresh-token", + ) + + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Nil(t, tok) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantToken.TokenType, tok.TokenType) + require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken) + require.Equal(t, tt.wantToken.AccessToken, tok.AccessToken) + testutil.RequireTimeInDelta(t, tt.wantToken.Expiry, tok.Expiry, 5*time.Second) + for k, v := range tt.wantTokenExtras { + require.Equal(t, v, tok.Extra(k)) + } + }) + } + }) + t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) { tests := []struct { name string diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 044417ec..3d62e05d 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -808,9 +808,9 @@ func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidcty func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { h.logger.V(debugLogLevel).Info("Pinniped: Refreshing cached token.") - refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) + upstreamOIDCIdentityProvider := h.getProvider(h.oauth2Config, h.provider, h.httpClient) - refreshed, err := refreshSource.Token() + refreshed, err := upstreamOIDCIdentityProvider.PerformRefresh(ctx, refreshToken.Token) if err != nil { // Ignore errors during refresh, but return nil which will trigger the full login flow. return nil, nil @@ -818,7 +818,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") + return upstreamOIDCIdentityProvider.ValidateToken(ctx, refreshed, "") } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index c63173df..2c30ffb0 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -35,6 +35,7 @@ import ( "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/testlogger" + "go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -404,11 +405,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(&testToken, nil) + mock.EXPECT(). + PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock } @@ -445,11 +452,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(nil, fmt.Errorf("some validation error")) + mock.EXPECT(). + PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token"). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock } @@ -1522,11 +1535,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) h.cache = cache - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(&testToken, nil) + mock.EXPECT(). + PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock }