diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index dfbe785b..fd8b7dd9 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -200,6 +200,21 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PasswordCredentialsGran return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCredentialsGrantAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PasswordCredentialsGrantAndValidateTokens), arg0, arg1, arg2) } +// PerformRefresh mocks base method. +func (m *MockUpstreamOIDCIdentityProviderI) PerformRefresh(arg0 context.Context, arg1 string) (*oauth2.Token, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PerformRefresh", arg0, arg1) + ret0, _ := ret[0].(*oauth2.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PerformRefresh indicates an expected call of PerformRefresh. +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1) +} + // ValidateToken mocks base method. func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) { m.ctrl.T.Helper() diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index b70e0309..6c3c1918 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -210,8 +210,6 @@ func FositeOauth2Helper( // The default is to support all prompt values from the spec. // See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest - // We'll make a best effort to support these by passing the value of this prompt param to the upstream IDP - // and rely on its implementation of this param. AllowedPromptValues: nil, // Use the fosite default to make it more likely that off the shelf OIDC clients can work with the supervisor. @@ -232,7 +230,7 @@ func FositeOauth2Helper( compose.OpenIDConnectExplicitFactory, compose.OpenIDConnectRefreshFactory, compose.OAuth2PKCEFactory, - TokenExchangeFactory, + TokenExchangeFactory, // handle the "urn:ietf:params:oauth:grant-type:token-exchange" grant type ) provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template() return provider diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 6d0af8ec..9ece5f4d 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -63,6 +63,14 @@ type UpstreamOIDCIdentityProviderI interface { redirectURI string, ) (*oidctypes.Token, error) + // PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not + // return a new ID or refresh token in the response. If it returns an ID token, then use ValidateRefresh to + // validate the ID token. + PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) + + // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response + // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated + // tokens, or an error. ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) } diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index ea1d2d62..708d4855 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -130,6 +130,7 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs ) m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler( + m.upstreamIDPs, oauthHelperWithKubeStorage, ) diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 72222367..724ee0aa 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -5,17 +5,36 @@ package token import ( + "context" "net/http" "github.com/ory/fosite" + "github.com/ory/x/errorsx" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/psession" ) +var ( + errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{ + ErrorField: "error", + DescriptionField: "There was an internal server error.", + HintField: "Required upstream data not found in session.", + CodeField: http.StatusInternalServerError, + } + + errUpstreamRefreshError = &fosite.RFC6749Error{ + ErrorField: "error", + DescriptionField: "Error during upstream refresh.", + CodeField: http.StatusUnauthorized, + } +) + func NewHandler( + idpLister oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { @@ -27,6 +46,20 @@ func NewHandler( return nil } + // Check if we are performing a refresh grant. + if accessRequest.GetGrantTypes().ExactOne("refresh_token") { + // The above call to NewAccessRequest has loaded the session from storage into the accessRequest variable. + // The session, requested scopes, and requested audience from the original authorize request was retrieved + // from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may + // have already been granted on the accessRequest. + err = upstreamRefresh(r.Context(), accessRequest, idpLister) + if err != nil { + plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...) + oauthHelper.WriteAccessError(w, accessRequest, err) + return nil + } + } + accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest) if err != nil { plog.Info("token response error", oidc.FositeErrorForLog(err)...) @@ -39,3 +72,97 @@ func NewHandler( return nil }) } + +func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { + session := accessRequest.GetSession().(*psession.PinnipedSession) + customSessionData := session.Custom + if customSessionData == nil { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + providerName := customSessionData.ProviderName + providerUID := customSessionData.ProviderUID + if providerUID == "" || providerName == "" { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + switch customSessionData.ProviderType { + case psession.ProviderTypeOIDC: + return upstreamOIDCRefresh(ctx, customSessionData, providerCache) + case psession.ProviderTypeLDAP: + // upstream refresh not yet implemented for LDAP, so do nothing + case psession.ProviderTypeActiveDirectory: + // upstream refresh not yet implemented for AD, so do nothing + default: + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + return nil +} + +func upstreamOIDCRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { + if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + + p, err := findOIDCProviderByNameAndValidateUID(s, providerCache) + if err != nil { + return err + } + + plog.Debug("attempting upstream refresh request", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + + refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh failed using provider %q of type %q.", + s.ProviderName, s.ProviderType).WithWrap(err)) + } + + // Upstream refresh may or may not return a new ID token. From the spec: + // "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token." + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + _, hasIDTok := refreshedTokens.Extra("id_token").(string) + if hasIDTok { + // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at + // least some providers do not include one, so we skip the nonce validation here (but not other validations). + _, err = p.ValidateToken(ctx, refreshedTokens, "") + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh returned an invalid ID token using provider %q of type %q.", + s.ProviderName, s.ProviderType).WithWrap(err)) + } + } else { + plog.Debug("upstream refresh request did not return a new ID token", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + } + + // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in + // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding + // overwriting the old one. + if refreshedTokens.RefreshToken != "" { + plog.Debug("upstream refresh request did not return a new refresh token", + "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) + s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken + } + + return nil +} + +func findOIDCProviderByNameAndValidateUID( + s *psession.CustomSessionData, + providerCache oidc.UpstreamIdentityProvidersLister, +) (provider.UpstreamOIDCIdentityProviderI, error) { + for _, p := range providerCache.GetOIDCIdentityProviders() { + if p.GetName() == s.ProviderName { + if p.GetResourceUID() != s.ProviderUID { + return nil, errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Provider %q of type %q from upstream session data has changed its resource UID since authentication.", + s.ProviderName, s.ProviderType)) + } + return p, nil + } + } + return nil, errorsx.WithStack(errUpstreamRefreshError. + WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) +} diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index 262f5756..968944d7 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -23,12 +23,13 @@ import ( "time" "github.com/ory/fosite" - "github.com/ory/fosite/handler/oauth2" + fositeoauth2 "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/pkce" "github.com/ory/fosite/token/jwt" "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" josejwt "gopkg.in/square/go-jose.v2/jwt" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -44,8 +45,10 @@ import ( "go.pinniped.dev/internal/fositestorage/refreshtoken" "go.pinniped.dev/internal/fositestoragei" "go.pinniped.dev/internal/here" + "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/jwks" + "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" @@ -179,6 +182,13 @@ var ( } `) + pinnipedUpstreamSessionDataNotFoundErrorBody = here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `) + happyAuthRequest = &http.Request{ Form: url.Values{ "response_type": {"code"}, @@ -206,12 +216,25 @@ var ( } ) +type expectedUpstreamRefresh struct { + performedByUpstreamName string + args *oidctestutil.PerformRefreshArgs +} + +type expectedUpstreamValidateTokens struct { + performedByUpstreamName string + args *oidctestutil.ValidateTokenArgs +} + type tokenEndpointResponseExpectedValues struct { - wantStatus int - wantSuccessBodyFields []string - wantErrorResponseBody string - wantRequestedScopes []string - wantGrantedScopes []string + wantStatus int + wantSuccessBodyFields []string + wantErrorResponseBody string + wantRequestedScopes []string + wantGrantedScopes []string + wantUpstreamOIDCRefreshCall *expectedUpstreamRefresh + wantUpstreamOIDCValidateTokenCall *expectedUpstreamValidateTokens + wantCustomSessionDataStored *psession.CustomSessionData } type authcodeExchangeInputs struct { @@ -222,13 +245,9 @@ type authcodeExchangeInputs struct { s fositestoragei.AllFositeStorage, authCode string, ) - makeOathHelper func( - t *testing.T, - authRequest *http.Request, - store fositestoragei.AllFositeStorage, - ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) - - want tokenEndpointResponseExpectedValues + makeOathHelper OauthHelperFactoryFunc + customSessionData *psession.CustomSessionData + want tokenEndpointResponseExpectedValues } func TestTokenEndpoint(t *testing.T) { @@ -520,7 +539,8 @@ func TestTokenEndpoint(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + exchangeAuthcodeForTokens(t, test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) }) } } @@ -549,7 +569,9 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) { t.Parallel() // First call - should be successful. - subject, rsp, authCode, _, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + subject, rsp, authCode, _, secrets, oauthStore := exchangeAuthcodeForTokens(t, + test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) var parsedResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody)) @@ -573,9 +595,11 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) { requireInvalidAccessTokenStorage(t, parsedResponseBody, oauthStore) // This was previously invalidated by the first request, so it remains invalidated requireInvalidPKCEStorage(t, authCode, oauthStore) - // Fosite never cleans up OpenID Connect session storage, so it is still there + // Fosite never cleans up OpenID Connect session storage, so it is still there. + // Note that customSessionData is only relevant to refresh grant, so we leave it as nil for this + // authcode exchange test, even though in practice it would actually be in the session. requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, - test.authcodeExchange.want.wantRequestedScopes, test.authcodeExchange.want.wantGrantedScopes) + test.authcodeExchange.want.wantRequestedScopes, test.authcodeExchange.want.wantGrantedScopes, nil) // Check that the access token and refresh token storage were both deleted, and the number of other storage objects did not change. testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) @@ -602,6 +626,7 @@ func TestTokenExchange(t *testing.T) { }, want: successfulAuthCodeExchange, } + tests := []struct { name string @@ -742,7 +767,9 @@ func TestTokenExchange(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() - subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange) + // Authcode exchange doesn't use the upstream provider cache, so just pass an empty cache. + subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, + test.authcodeExchange, oidctestutil.NewUpstreamIDPListerBuilder().Build()) var parsedAuthcodeExchangeResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) @@ -845,80 +872,219 @@ type refreshRequestInputs struct { } func TestRefreshGrant(t *testing.T) { + const ( + oidcUpstreamName = "some-oidc-idp" + oidcUpstreamResourceUID = "oidc-resource-uid" + oidcUpstreamType = "oidc" + oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token" + oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token" + oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token" + ) + + // The below values are funcs so every test can have its own copy of the objects, to avoid data races + // in these parallel tests. + + upstreamOIDCIdentityProviderBuilder := func() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { + return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(oidcUpstreamName). + WithResourceUID(oidcUpstreamResourceUID) + } + + initialUpstreamOIDCCustomSessionData := func() *psession.CustomSessionData { + return &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: oidcUpstreamInitialRefreshToken, + }, + } + } + + upstreamOIDCCustomSessionDataWithNewRefreshToken := func(newRefreshToken string) *psession.CustomSessionData { + sessionData := initialUpstreamOIDCCustomSessionData() + sessionData.OIDC.UpstreamRefreshToken = newRefreshToken + return sessionData + } + + happyUpstreamRefreshCall := func() *expectedUpstreamRefresh { + return &expectedUpstreamRefresh{ + performedByUpstreamName: oidcUpstreamName, + args: &oidctestutil.PerformRefreshArgs{ + Ctx: nil, // this will be filled in with the actual request context by the test below + RefreshToken: oidcUpstreamInitialRefreshToken, + }, + } + } + + happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token) *expectedUpstreamValidateTokens { + return &expectedUpstreamValidateTokens{ + performedByUpstreamName: oidcUpstreamName, + args: &oidctestutil.ValidateTokenArgs{ + Ctx: nil, // this will be filled in with the actual request context by the test below + Tok: expectedTokens, + ExpectedIDTokenNonce: "", // always expect empty string + }, + } + } + + happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess := func(wantCustomSessionDataStored *psession.CustomSessionData) tokenEndpointResponseExpectedValues { + want := tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access"}, + wantGrantedScopes: []string{"openid", "offline_access"}, + wantCustomSessionDataStored: wantCustomSessionDataStored, + } + return want + } + + happyRefreshTokenResponseForOpenIDAndOfflineAccess := func(wantCustomSessionDataStored *psession.CustomSessionData, expectToValidateToken *oauth2.Token) tokenEndpointResponseExpectedValues { + // Should always have some custom session data stored. The other expectations happens to be the + // same as the same values as the authcode exchange case. + want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) + // Should always try to perform an upstream refresh. + want.wantUpstreamOIDCRefreshCall = happyUpstreamRefreshCall() + // Should only try to ValidateToken when there was an id token returned by the upstream refresh. + if expectToValidateToken != nil { + want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) + } + return want + } + + refreshedUpstreamTokensWithRefreshTokenWithoutIDToken := func() *oauth2.Token { + return &oauth2.Token{ + AccessToken: "fake-refreshed-access-token", + TokenType: "Bearer", + RefreshToken: oidcUpstreamRefreshedRefreshToken, + Expiry: time.Date(2050, 1, 1, 1, 1, 1, 1, time.UTC), + } + } + + refreshedUpstreamTokensWithIDAndRefreshTokens := func() *oauth2.Token { + return refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(). + WithExtra(map[string]interface{}{"id_token": oidcUpstreamRefreshedIDToken}) + } + + refreshedUpstreamTokensWithIDTokenWithoutRefreshToken := func() *oauth2.Token { + tokens := refreshedUpstreamTokensWithIDAndRefreshTokens() + tokens.RefreshToken = "" // remove the refresh token + return tokens + } + tests := []struct { name string + idps *oidctestutil.UpstreamIDPListerBuilder authcodeExchange authcodeExchangeInputs refreshRequest refreshRequestInputs }{ { - name: "happy path refresh grant with ID token", + name: "happy path refresh grant with openid scope granted (id token returned)", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { - name: "happy path refresh grant without ID token", + name: "happy path refresh grant without openid scope granted (no id token returned)", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, - }}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + }, + }, + }, + { + name: "happy path refresh grant when the upstream refresh does not return a new ID token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + nil, // expect ValidateToken is *not* called + ), + }, + }, + { + name: "happy path refresh grant when the upstream refresh does not return a new refresh token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDTokenWithoutRefreshToken()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + initialUpstreamOIDCCustomSessionData(), // still has the initial refresh token stored + refreshedUpstreamTokensWithIDTokenWithoutRefreshToken(), + ), + }, }, { name: "when the refresh request adds a new scope to the list of requested scopes then it is ignored", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("openid some-other-scope-not-from-auth-request").ReadCloser() }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { name: "when the refresh request removes a scope which was originally granted from the list of requested scopes then it is granted anyway", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access pinniped:request-audience") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -926,43 +1092,47 @@ func TestRefreshGrant(t *testing.T) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("openid").ReadCloser() // do not ask for "pinniped:request-audience" again }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - }}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + }, + }, }, { name: "when the refresh request does not include a scope param then it gets all the same scopes as the original authorization request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { r.Body = happyRefreshRequestBody(refreshToken).WithScope("").ReadCloser() }, - want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"openid", "offline_access"}, - wantGrantedScopes: []string{"openid", "offline_access"}, - }}, + want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( + upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + refreshedUpstreamTokensWithIDAndRefreshTokens(), + ), + }, }, { name: "when a bad refresh token is sent in the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -972,17 +1142,21 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusBadRequest, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody, - }}, + }, + }, }, { name: "when the access token is sent as if it were a refresh token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -992,17 +1166,21 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusBadRequest, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody, - }}, + }, + }, }, { name: "when the wrong client ID is included in the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusOK, - wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, - wantRequestedScopes: []string{"offline_access"}, - wantGrantedScopes: []string{"offline_access"}, + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"offline_access"}, + wantGrantedScopes: []string{"offline_access"}, + wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1012,7 +1190,301 @@ func TestRefreshGrant(t *testing.T) { want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusUnauthorized, wantErrorResponseBody: fositeInvalidClientErrorBody, - }}, + }, + }, + }, + { + name: "when there is no custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: nil, // this should not happen in practice + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(nil), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider name in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: "", // this should not happen in practice + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: "", // this should not happen in practice + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider UID in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: "", // this should not happen in practice + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: "", // this should not happen in practice + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no provider type in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is an illegal provider type in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "not-an-allowed-provider-type", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: "not-an-allowed-provider-type", // this should not happen in practice + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no OIDC-specific data in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: nil, // this should not happen in practice + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: nil, // this should not happen in practice + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when there is no OIDC refresh token in custom session data found in the session storage during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: "", // this should not happen in practice + }, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamRefreshToken: "", // this should not happen in practice + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: pinnipedUpstreamSessionDataNotFoundErrorBody, + }, + }, + }, + { + name: "when the provider in the session storage is not found due to its name during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: "this-name-will-not-be-found", // this could happen if the OIDCIdentityProvider was deleted since original login + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: "this-name-will-not-be-found", // this could happen if the OIDCIdentityProvider was deleted since original login + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'this-name-will-not-be-found' of type 'oidc' from upstream session data was not found." + } + `), + }, + }, + }, + { + name: "when the provider in the session storage is found but has the wrong resource UID during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: "this is the wrong uid", // this could happen if the OIDCIdentityProvider was deleted and recreated at the same name since original login + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ // want the initial customSessionData to be unmodified + ProviderName: oidcUpstreamName, + ProviderUID: "this is the wrong uid", // this could happen if the OIDCIdentityProvider was deleted and recreated at the same name since original login + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{UpstreamRefreshToken: oidcUpstreamInitialRefreshToken}, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'some-oidc-idp' of type 'oidc' from upstream session data has changed its resource UID since authentication." + } + `), + }, + }, + }, + { + name: "when the upstream refresh fails during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). + WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh failed using provider 'some-oidc-idp' of type 'oidc'." + } + `), + }, + }, + }, + { + name: "when the upstream refresh returns an invalid ID token during the refresh request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). + WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). + // This is the current format of the errors returned by the production code version of ValidateToken, see ValidateToken in upstreamoidc.go + WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). + Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh returned an invalid ID token using provider 'some-oidc-idp' of type 'oidc'." + } + `), + }, + }, }, } for _, test := range tests { @@ -1021,10 +1493,15 @@ func TestRefreshGrant(t *testing.T) { t.Parallel() // First exchange the authcode for tokens, including a refresh token. - subject, rsp, authCode, jwtSigningKey, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange) + subject, rsp, authCode, jwtSigningKey, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange, test.idps.Build()) var parsedAuthcodeExchangeResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) + // Performing an authcode exchange should not have caused any upstream refresh, which should only + // happen during a downstream refresh. + test.idps.RequireExactlyZeroCallsToPerformRefresh(t) + test.idps.RequireExactlyZeroCallsToValidateToken(t) + // Wait one second before performing the refresh so we can see that the refreshed ID token has new issued // at and expires at dates which are newer than the old tokens. // If this gets too annoying in terms of making our test suite slower then we can remove it and adjust @@ -1033,8 +1510,10 @@ func TestRefreshGrant(t *testing.T) { // Send the refresh token back and preform a refresh. firstRefreshToken := parsedAuthcodeExchangeResponseBody["refresh_token"].(string) + require.NotEmpty(t, firstRefreshToken) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") req := httptest.NewRequest("POST", "/path/shouldn't/matter", - happyRefreshRequestBody(firstRefreshToken).ReadCloser()) + happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") if test.refreshRequest.modifyTokenRequest != nil { test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string)) @@ -1045,11 +1524,45 @@ func TestRefreshGrant(t *testing.T) { t.Logf("second response: %#v", refreshResponse) t.Logf("second response body: %q", refreshResponse.Body.String()) + // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. + if test.refreshRequest.want.wantUpstreamOIDCRefreshCall != nil { + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToPerformRefresh(t, + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.performedByUpstreamName, + test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args, + ) + } else { + test.idps.RequireExactlyZeroCallsToPerformRefresh(t) + } + + // Test that we did or did not make a call to the upstream OIDC provider interface to validate the + // new ID token that was returned by the upstream refresh. + if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil { + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToValidateToken(t, + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName, + test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args, + ) + } else { + test.idps.RequireExactlyZeroCallsToValidateToken(t) + } + // The bug in fosite that prevents at_hash from appearing in the initial ID token does not impact the refreshed ID token wantAtHashClaimInIDToken := true // Refreshed ID tokens do not include the nonce from the original auth request wantNonceValueInIDToken := false - requireTokenEndpointBehavior(t, test.refreshRequest.want, wantAtHashClaimInIDToken, wantNonceValueInIDToken, refreshResponse, authCode, oauthStore, jwtSigningKey, secrets) + + requireTokenEndpointBehavior(t, + test.refreshRequest.want, + test.authcodeExchange.customSessionData, + wantAtHashClaimInIDToken, + wantNonceValueInIDToken, + refreshResponse, + authCode, + oauthStore, + jwtSigningKey, + secrets, + ) if test.refreshRequest.want.wantStatus == http.StatusOK { wantIDToken := contains(test.refreshRequest.want.wantSuccessBodyFields, "id_token") @@ -1109,7 +1622,7 @@ func requireClaimsAreEqual(t *testing.T, claimName string, claimsOfTokenA map[st require.Equal(t, claimsOfTokenA[claimName], claimsOfTokenB[claimName]) } -func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( +func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps provider.DynamicUpstreamIDPProvider) ( subject http.Handler, rsp *httptest.ResponseRecorder, authCode string, @@ -1129,15 +1642,17 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( oauthStore = oidc.NewKubeStorage(secrets, oidc.DefaultOIDCTimeoutsConfiguration()) if test.makeOathHelper != nil { - oauthHelper, authCode, jwtSigningKey = test.makeOathHelper(t, authRequest, oauthStore) + oauthHelper, authCode, jwtSigningKey = test.makeOathHelper(t, authRequest, oauthStore, test.customSessionData) } else { - oauthHelper, authCode, jwtSigningKey = makeHappyOauthHelper(t, authRequest, oauthStore) + // Note that makeHappyOauthHelper() calls simulateAuthEndpointHavingAlreadyRun() to preload the session storage. + oauthHelper, authCode, jwtSigningKey = makeHappyOauthHelper(t, authRequest, oauthStore, test.customSessionData) } if test.modifyStorage != nil { test.modifyStorage(t, oauthStore, authCode) } - subject = NewHandler(oauthHelper) + + subject = NewHandler(idps, oauthHelper) authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid") expectedNumberOfIDSessionsStored := 0 @@ -1163,7 +1678,18 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( wantAtHashClaimInIDToken := false // due to a bug in fosite, the at_hash claim is not filled in during authcode exchange wantNonceValueInIDToken := true // ID tokens returned by the authcode exchange must include the nonce from the auth request (unliked refreshed ID tokens) - requireTokenEndpointBehavior(t, test.want, wantAtHashClaimInIDToken, wantNonceValueInIDToken, rsp, authCode, oauthStore, jwtSigningKey, secrets) + + requireTokenEndpointBehavior(t, + test.want, + test.customSessionData, + wantAtHashClaimInIDToken, + wantNonceValueInIDToken, + rsp, + authCode, + oauthStore, + jwtSigningKey, + secrets, + ) return subject, rsp, authCode, jwtSigningKey, secrets, oauthStore } @@ -1171,6 +1697,7 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) ( func requireTokenEndpointBehavior( t *testing.T, test tokenEndpointResponseExpectedValues, + oldCustomSessionData *psession.CustomSessionData, wantAtHashClaimInIDToken bool, wantNonceValueInIDToken bool, tokenEndpointResponse *httptest.ResponseRecorder, @@ -1193,9 +1720,10 @@ func requireTokenEndpointBehavior( wantRefreshToken := contains(test.wantSuccessBodyFields, "refresh_token") requireInvalidAuthCodeStorage(t, authCode, oauthStore, secrets) - requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, secrets) + requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, test.wantCustomSessionDataStored, secrets) requireInvalidPKCEStorage(t, authCode, oauthStore) - requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes) + // Performing a refresh does not update the OIDC storage, so after a refresh it should still have the old custom session data from the initial login. + requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, oldCustomSessionData) expectedNumberOfRefreshTokenSessionsStored := 0 if wantRefreshToken { @@ -1207,7 +1735,7 @@ func requireTokenEndpointBehavior( requireValidIDToken(t, parsedResponseBody, jwtSigningKey, wantAtHashClaimInIDToken, wantNonceValueInIDToken, parsedResponseBody["access_token"].(string)) } if wantRefreshToken { - requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, secrets) + requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes, test.wantCustomSessionDataStored, secrets) } testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) @@ -1304,16 +1832,24 @@ func getFositeDataSignature(t *testing.T, data string) string { return split[1] } +type OauthHelperFactoryFunc func( + t *testing.T, + authRequest *http.Request, + store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, +) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) + func makeHappyOauthHelper( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwtSigningKey, jwkProvider := generateJWTSigningKeyAndJWKSProvider(t, goodIssuer) oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, jwkProvider, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), jwtSigningKey } @@ -1334,12 +1870,13 @@ func makeOauthHelperWithJWTKeyThatWorksOnlyOnce( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwtSigningKey, jwkProvider := generateJWTSigningKeyAndJWKSProvider(t, goodIssuer) oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, &singleUseJWKProvider{DynamicJWKSProvider: jwkProvider}, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), jwtSigningKey } @@ -1347,17 +1884,23 @@ func makeOauthHelperWithNilPrivateJWTSigningKey( t *testing.T, authRequest *http.Request, store fositestoragei.AllFositeStorage, + initialCustomSessionData *psession.CustomSessionData, ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { t.Helper() jwkProvider := jwks.NewDynamicJWKSProvider() // empty provider which contains no signing key for this issuer oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, hmacSecretFunc, jwkProvider, oidc.DefaultOIDCTimeoutsConfiguration()) - authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper) + authResponder := simulateAuthEndpointHavingAlreadyRun(t, authRequest, oauthHelper, initialCustomSessionData) return oauthHelper, authResponder.GetCode(), nil } // Simulate the auth endpoint running so Fosite code will fill the store with realistic values. -func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Request, oauthHelper fosite.OAuth2Provider) fosite.AuthorizeResponder { +func simulateAuthEndpointHavingAlreadyRun( + t *testing.T, + authRequest *http.Request, + oauthHelper fosite.OAuth2Provider, + initialCustomSessionData *psession.CustomSessionData, +) fosite.AuthorizeResponder { // We only set the fields in the session that Fosite wants us to set. ctx := context.Background() session := &psession.PinnipedSession{ @@ -1374,11 +1917,7 @@ func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Reques Subject: "", // not used, note that callback_handler.go does not set this Username: "", // not used, note that callback_handler.go does not set this }, - Custom: &psession.CustomSessionData{ - OIDC: &psession.OIDCSessionData{ - UpstreamRefreshToken: "starting-fake-refresh-token", - }, - }, + Custom: initialCustomSessionData, } authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest) require.NoError(t, err) @@ -1416,7 +1955,7 @@ func generateJWTSigningKeyAndJWKSProvider(t *testing.T, issuer string) (*ecdsa.P func requireInvalidAuthCodeStorage( t *testing.T, code string, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, secrets v1.SecretInterface, ) { t.Helper() @@ -1431,9 +1970,10 @@ func requireInvalidAuthCodeStorage( func requireValidRefreshTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, secrets v1.SecretInterface, ) { t.Helper() @@ -1455,6 +1995,7 @@ func requireValidRefreshTokenStorage( wantRequestedScopes, wantGrantedScopes, true, + wantCustomSessionData, ) requireGarbageCollectTimeInDelta(t, refreshTokenString, "refresh-token", secrets, time.Now().Add(9*time.Hour).Add(2*time.Minute), 1*time.Minute) @@ -1463,9 +2004,10 @@ func requireValidRefreshTokenStorage( func requireValidAccessTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, secrets v1.SecretInterface, ) { t.Helper() @@ -1506,6 +2048,7 @@ func requireValidAccessTokenStorage( wantRequestedScopes, wantGrantedScopes, true, + wantCustomSessionData, ) requireGarbageCollectTimeInDelta(t, accessTokenString, "access-token", secrets, time.Now().Add(9*time.Hour).Add(2*time.Minute), 1*time.Minute) @@ -1514,7 +2057,7 @@ func requireValidAccessTokenStorage( func requireInvalidAccessTokenStorage( t *testing.T, body map[string]interface{}, - storage oauth2.CoreStorage, + storage fositeoauth2.CoreStorage, ) { t.Helper() @@ -1547,6 +2090,7 @@ func requireValidOIDCStorage( storage openid.OpenIDConnectRequestStorage, wantRequestedScopes []string, wantGrantedScopes []string, + wantCustomSessionData *psession.CustomSessionData, ) { t.Helper() @@ -1569,6 +2113,7 @@ func requireValidOIDCStorage( wantRequestedScopes, wantGrantedScopes, false, + wantCustomSessionData, ) } else { _, err := storage.GetOpenIDConnectSession(context.Background(), code, nil) @@ -1583,6 +2128,7 @@ func requireValidStoredRequest( wantRequestedScopes []string, wantGrantedScopes []string, wantAccessTokenExpiresAt bool, + wantCustomSessionData *psession.CustomSessionData, ) { t.Helper() @@ -1669,6 +2215,9 @@ func requireValidStoredRequest( // We don't use these, so they should be empty. require.Empty(t, session.Fosite.Username) require.Empty(t, session.Fosite.Subject) + + // The custom session data was stored as expected. + require.Equal(t, wantCustomSessionData, session.Custom) } func requireGarbageCollectTimeInDelta(t *testing.T, tokenString string, typeLabel string, secrets v1.SecretInterface, wantExpirationTime time.Time, deltaTime time.Duration) { diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 3f8c3f57..90361ddd 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -58,6 +58,21 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct { Password string } +// PerformRefreshArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). +type PerformRefreshArgs struct { + Ctx context.Context + RefreshToken string +} + +// ValidateTokenArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc(). +type ValidateTokenArgs struct { + Ctx context.Context + Tok *oauth2.Token + ExpectedIDTokenNonce nonce.Nonce +} + type TestUpstreamLDAPIdentityProvider struct { Name string ResourceUID types.UID @@ -107,10 +122,18 @@ type TestUpstreamOIDCIdentityProvider struct { password string, ) (*oidctypes.Token, error) + PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) + + ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) + exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs + performRefreshCallCount int + performRefreshArgs []*PerformRefreshArgs + validateTokenCallCount int + validateTokenArgs []*ValidateTokenArgs } var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} @@ -193,8 +216,51 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs return u.exchangeAuthcodeAndValidateTokensArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) { - panic("implement me") +func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + u.performRefreshCallCount++ + u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ + Ctx: ctx, + RefreshToken: refreshToken, + }) + return u.PerformRefreshFunc(ctx, refreshToken) +} + +func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int { + return u.performRefreshCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + return u.performRefreshArgs[call] +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.validateTokenArgs == nil { + u.validateTokenArgs = make([]*ValidateTokenArgs, 0) + } + u.validateTokenCallCount++ + u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{ + Ctx: ctx, + Tok: tok, + ExpectedIDTokenNonce: expectedIDTokenNonce, + }) + return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce) +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int { + return u.validateTokenCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs { + if u.validateTokenArgs == nil { + u.validateTokenArgs = make([]*ValidateTokenArgs, 0) + } + return u.validateTokenArgs[call] } type UpstreamIDPListerBuilder struct { @@ -316,6 +382,80 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV ) } +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *PerformRefreshArgs, +) { + t.Helper() + var actualArgs *PerformRefreshArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.performRefreshArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to PerformRefresh() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "PerformRefresh() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.performRefreshCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to PerformRefresh()", + ) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *ValidateTokenArgs, +) { + t.Helper() + var actualArgs *ValidateTokenArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.validateTokenArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to ValidateToken() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "ValidateToken() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to ValidateToken()", + ) +} + func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder { return &UpstreamIDPListerBuilder{} } @@ -329,11 +469,15 @@ type TestUpstreamOIDCIdentityProviderBuilder struct { refreshToken *oidctypes.RefreshToken usernameClaim string groupsClaim string + refreshedTokens *oauth2.Token + validatedTokens *oidctypes.Token authorizationURL url.URL additionalAuthcodeParams map[string]string allowPasswordGrant bool authcodeExchangeErr error passwordGrantErr error + performRefreshErr error + validateTokenErr error } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { @@ -429,6 +573,26 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err err return u } +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshedTokens(tokens *oauth2.Token) *TestUpstreamOIDCIdentityProviderBuilder { + u.refreshedTokens = tokens + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.performRefreshErr = err + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { + u.validatedTokens = tokens + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.validateTokenErr = err + return u +} + func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider { return &TestUpstreamOIDCIdentityProvider{ Name: u.name, @@ -452,6 +616,18 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent } return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken}, nil }, + PerformRefreshFunc: func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + if u.performRefreshErr != nil { + return nil, u.performRefreshErr + } + return u.refreshedTokens, nil + }, + ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.validateTokenErr != nil { + return nil, u.validateTokenErr + } + return u.validatedTokens, nil + }, } } diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 17e70915..34c27be2 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -47,6 +47,8 @@ type ProviderConfig struct { } } +var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) + func (p *ProviderConfig) GetResourceUID() types.UID { return p.ResourceUID } @@ -120,6 +122,14 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, return p.ValidateToken(ctx, tok, expectedIDTokenNonce) } +func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Create a TokenSource without an access token, so it thinks that a refresh is immediately required. + // Then ask it for the tokens to cause it to perform the refresh and return the results. + return p.Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}).Token() +} + +// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response, +// if the provider offers the userinfo endpoint. func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { @@ -146,7 +156,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e } maybeLogClaims("claims from ID token", p.Name, validatedClaims) - if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil { + if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims); err != nil { return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) } @@ -167,7 +177,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e }, nil } -func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error { +func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error { idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string) if len(idTokenSubject) == 0 { return nil // defer to existing ID token validation diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 2ffd40ca..342683fc 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -23,6 +23,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" ) @@ -288,6 +289,171 @@ func TestProviderConfig(t *testing.T) { } }) + t.Run("PerformRefresh", func(t *testing.T) { + tests := []struct { + name string + returnIDTok string + returnAccessTok string + returnRefreshTok string + returnTokType string + returnExpiresIn string + tokenStatusCode int + + wantErr string + wantToken *oauth2.Token + wantTokenExtras map[string]interface{} + }{ + { + name: "success when the server returns all tokens in the refresh result", + returnIDTok: "test-id-token", + returnAccessTok: "test-access-token", + returnRefreshTok: "test-refresh-token", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the ID token only appears in the extras map + "id_token": "test-id-token", + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "success when the server does not return a new refresh token in the refresh result", + returnIDTok: "test-id-token", + returnAccessTok: "test-access-token", + returnRefreshTok: "", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + // the library sets the original refresh token into the result, even though the server did not return that + RefreshToken: "test-initial-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the ID token only appears in the extras map + "id_token": "test-id-token", + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "success when the server does not return a new ID token in the refresh result", + returnIDTok: "", + returnAccessTok: "test-access-token", + returnRefreshTok: "test-refresh-token", + returnTokType: "test-token-type", + returnExpiresIn: "42", + tokenStatusCode: http.StatusOK, + wantToken: &oauth2.Token{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + TokenType: "test-token-type", + Expiry: time.Now().Add(42 * time.Second), + }, + wantTokenExtras: map[string]interface{}{ + // the library also repeats all the other keys/values returned by the server in the raw extras map + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "token_type": "test-token-type", + "expires_in": "42", + // the library also adds this zero-value even though the server did not return it + "expiry": "0001-01-01T00:00:00Z", + }, + }, + { + name: "server returns an error on token refresh", + tokenStatusCode: http.StatusForbidden, + wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: fake error\n", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, r.ParseForm()) + require.Equal(t, 4, len(r.Form)) + require.Equal(t, "test-client-id", r.Form.Get("client_id")) + require.Equal(t, "test-client-secret", r.Form.Get("client_secret")) + require.Equal(t, "refresh_token", r.Form.Get("grant_type")) + require.Equal(t, "test-initial-refresh-token", r.Form.Get("refresh_token")) + if tt.tokenStatusCode != http.StatusOK { + http.Error(w, "fake error", tt.tokenStatusCode) + return + } + var response struct { + oauth2.Token + IDToken string `json:"id_token,omitempty"` + ExpiresIn string `json:"expires_in,omitempty"` + } + response.IDToken = tt.returnIDTok + response.AccessToken = tt.returnAccessTok + response.RefreshToken = tt.returnRefreshTok + response.TokenType = tt.returnTokType + response.ExpiresIn = tt.returnExpiresIn + w.Header().Set("content-type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(&response)) + })) + t.Cleanup(tokenServer.Close) + + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + Endpoint: oauth2.Endpoint{ + AuthURL: "https://example.com", + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + } + + tok, err := p.PerformRefresh( + context.Background(), + "test-initial-refresh-token", + ) + + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Nil(t, tok) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantToken.TokenType, tok.TokenType) + require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken) + require.Equal(t, tt.wantToken.AccessToken, tok.AccessToken) + testutil.RequireTimeInDelta(t, tt.wantToken.Expiry, tok.Expiry, 5*time.Second) + for k, v := range tt.wantTokenExtras { + require.Equal(t, v, tok.Extra(k)) + } + }) + } + }) + t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) { tests := []struct { name string diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 044417ec..3d62e05d 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -808,9 +808,9 @@ func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidcty func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { h.logger.V(debugLogLevel).Info("Pinniped: Refreshing cached token.") - refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) + upstreamOIDCIdentityProvider := h.getProvider(h.oauth2Config, h.provider, h.httpClient) - refreshed, err := refreshSource.Token() + refreshed, err := upstreamOIDCIdentityProvider.PerformRefresh(ctx, refreshToken.Token) if err != nil { // Ignore errors during refresh, but return nil which will trigger the full login flow. return nil, nil @@ -818,7 +818,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") + return upstreamOIDCIdentityProvider.ValidateToken(ctx, refreshed, "") } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index c63173df..2c30ffb0 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -35,6 +35,7 @@ import ( "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/testlogger" + "go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -404,11 +405,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(&testToken, nil) + mock.EXPECT(). + PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock } @@ -445,11 +452,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(nil, fmt.Errorf("some validation error")) + mock.EXPECT(). + PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token"). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock } @@ -1522,11 +1535,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) h.cache = cache - h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). Return(&testToken, nil) + mock.EXPECT(). + PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). + DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { + // Call the real production code to perform a refresh. + return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken) + }) return mock }