diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index 6467b4ae..c27c0184 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -244,16 +244,16 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0 } // ValidateTokenAndMergeWithUserInfo mocks base method. -func (m *MockUpstreamOIDCIdentityProviderI) ValidateTokenAndMergeWithUserInfo(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3 bool) (*oidctypes.Token, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ValidateTokenAndMergeWithUserInfo(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3, arg4 bool) (*oidctypes.Token, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateTokenAndMergeWithUserInfo", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ValidateTokenAndMergeWithUserInfo", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(*oidctypes.Token) ret1, _ := ret[1].(error) return ret0, ret1 } // ValidateTokenAndMergeWithUserInfo indicates an expected call of ValidateTokenAndMergeWithUserInfo. -func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateTokenAndMergeWithUserInfo(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateTokenAndMergeWithUserInfo(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateTokenAndMergeWithUserInfo", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateTokenAndMergeWithUserInfo), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateTokenAndMergeWithUserInfo", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateTokenAndMergeWithUserInfo), arg0, arg1, arg2, arg3, arg4) } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index a0691ed2..c471955f 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -77,7 +77,7 @@ type UpstreamOIDCIdentityProviderI interface { // ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated // tokens, or an error. - ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) + ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) } type UpstreamLDAPIdentityProviderI interface { diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index e0fc4408..ffcdf384 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -11,6 +11,7 @@ import ( "github.com/ory/fosite" "github.com/ory/x/errorsx" + "golang.org/x/oauth2" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc" @@ -101,7 +102,13 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error { s := session.Custom - if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" { + if s.OIDC == nil { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + accessTokenStored := s.OIDC.UpstreamAccessToken != "" + refreshTokenStored := s.OIDC.UpstreamRefreshToken != "" + refreshTokenOrAccessTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored) + if !refreshTokenOrAccessTokenStored { return errorsx.WithStack(errMissingUpstreamSessionInternalError) } @@ -113,21 +120,28 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, plog.Debug("attempting upstream refresh request", "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) - refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) - if err != nil { - return errorsx.WithStack(errUpstreamRefreshError.WithHint( - "Upstream refresh failed.", - ).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) + var tokens *oauth2.Token + if refreshTokenStored { + tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHint( + "Upstream refresh failed.", + ).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) + } + } else { + tokens = &oauth2.Token{ + AccessToken: s.OIDC.UpstreamAccessToken, + } } // Upstream refresh may or may not return a new ID token. From the spec: // "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token." // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse - _, hasIDTok := refreshedTokens.Extra("id_token").(string) + _, hasIDTok := tokens.Extra("id_token").(string) // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at // least some providers do not include one, so we skip the nonce validation here (but not other validations). - validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, refreshedTokens, "", hasIDTok) + validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored) if err != nil { return errorsx.WithStack(errUpstreamRefreshError.WithHintf( "Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) @@ -166,10 +180,10 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding // overwriting the old one. - if refreshedTokens.RefreshToken != "" { + if tokens.RefreshToken != "" { plog.Debug("upstream refresh request did not return a new refresh token", "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) - s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken + s.OIDC.UpstreamRefreshToken = tokens.RefreshToken } return nil diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index 2a8f039a..11405ed4 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -225,7 +225,7 @@ type expectedUpstreamRefresh struct { type expectedUpstreamValidateTokens struct { performedByUpstreamName string - args *oidctestutil.ValidateTokenArgs + args *oidctestutil.ValidateTokenAndMergeWithUserInfoArgs } type tokenEndpointResponseExpectedValues struct { @@ -881,6 +881,7 @@ func TestRefreshGrant(t *testing.T) { oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token" oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token" oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token" + oidcUpstreamAccessToken = "fake-upstream-access-token" //nolint:gosec ldapUpstreamName = "some-ldap-idp" ldapUpstreamResourceUID = "ldap-resource-uid" @@ -904,7 +905,7 @@ func TestRefreshGrant(t *testing.T) { WithResourceUID(oidcUpstreamResourceUID) } - initialUpstreamOIDCCustomSessionData := func() *psession.CustomSessionData { + initialUpstreamOIDCRefreshTokenCustomSessionData := func() *psession.CustomSessionData { return &psession.CustomSessionData{ ProviderName: oidcUpstreamName, ProviderUID: oidcUpstreamResourceUID, @@ -917,8 +918,21 @@ func TestRefreshGrant(t *testing.T) { } } + initialUpstreamOIDCAccessTokenCustomSessionData := func() *psession.CustomSessionData { + return &psession.CustomSessionData{ + ProviderName: oidcUpstreamName, + ProviderUID: oidcUpstreamResourceUID, + ProviderType: oidcUpstreamType, + OIDC: &psession.OIDCSessionData{ + UpstreamAccessToken: oidcUpstreamAccessToken, + UpstreamSubject: goodUpstreamSubject, + UpstreamIssuer: goodIssuer, + }, + } + } + upstreamOIDCCustomSessionDataWithNewRefreshToken := func(newRefreshToken string) *psession.CustomSessionData { - sessionData := initialUpstreamOIDCCustomSessionData() + sessionData := initialUpstreamOIDCRefreshTokenCustomSessionData() sessionData.OIDC.UpstreamRefreshToken = newRefreshToken return sessionData } @@ -957,13 +971,15 @@ func TestRefreshGrant(t *testing.T) { } } - happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token) *expectedUpstreamValidateTokens { + happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token, requireIDToken bool) *expectedUpstreamValidateTokens { return &expectedUpstreamValidateTokens{ performedByUpstreamName: oidcUpstreamName, - args: &oidctestutil.ValidateTokenArgs{ + args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ Ctx: nil, // this will be filled in with the actual request context by the test below Tok: expectedTokens, ExpectedIDTokenNonce: "", // always expect empty string + RequireUserInfo: false, + RequireIDToken: requireIDToken, }, } } @@ -986,7 +1002,7 @@ func TestRefreshGrant(t *testing.T) { // Should always try to perform an upstream refresh. want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall() if expectToValidateToken != nil { - want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) + want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken, true) } return want } @@ -1049,7 +1065,7 @@ func TestRefreshGrant(t *testing.T) { { name: "happy path refresh grant with openid scope granted (id token returned)", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": goodUpstreamSubject, @@ -1057,9 +1073,9 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( @@ -1071,7 +1087,7 @@ func TestRefreshGrant(t *testing.T) { { name: "refresh grant with unchanged username claim", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "some-claim": "some-value", @@ -1081,9 +1097,9 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( @@ -1092,23 +1108,64 @@ func TestRefreshGrant(t *testing.T) { ), }, }, + { + name: "refresh grant when the customsessiondata has a stored access token and no stored refresh token", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim"). + WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ + IDToken: &oidctypes.IDToken{ + Claims: map[string]interface{}{ + "some-claim": "some-value", + "sub": goodUpstreamSubject, + "username-claim": goodUsername, + }, + }, + AccessToken: &oidctypes.AccessToken{ + Token: oidcUpstreamAccessToken, + }, + }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCAccessTokenCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCAccessTokenCustomSessionData()), + }, // do not want upstreamRefreshRequest??? + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "id_token", "access_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access"}, + wantGrantedScopes: []string{"openid", "offline_access"}, + wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{ + oidcUpstreamName, + &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ + Ctx: nil, // this will be filled in with the actual request context by the test below + Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token + ExpectedIDTokenNonce: "", // always expect empty string + RequireIDToken: false, + RequireUserInfo: true, + }, + }, + wantCustomSessionDataStored: initialUpstreamOIDCAccessTokenCustomSessionData(), // doesn't change when we refresh + }, + }, + }, { name: "happy path refresh grant without openid scope granted (no id token returned)", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{}, }, - }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), + }).WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusOK, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, - wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), + wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1118,7 +1175,7 @@ func TestRefreshGrant(t *testing.T) { wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), false), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), }, }, @@ -1126,27 +1183,32 @@ func TestRefreshGrant(t *testing.T) { { name: "happy path refresh grant when the upstream refresh does not return a new ID token", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{}, }, }).WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ - want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( - upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), - refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), // expect ValidateTokenAndMergeWithUserInfo is called - ), + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusOK, + wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"}, + wantRequestedScopes: []string{"openid", "offline_access"}, + wantGrantedScopes: []string{"openid", "offline_access"}, + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), false), + wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), + }, }, }, { name: "happy path refresh grant when the upstream refresh does not return a new refresh token", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": goodUpstreamSubject, @@ -1154,13 +1216,13 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDTokenWithoutRefreshToken()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( - initialUpstreamOIDCCustomSessionData(), // still has the initial refresh token stored + initialUpstreamOIDCRefreshTokenCustomSessionData(), // still has the initial refresh token stored refreshedUpstreamTokensWithIDTokenWithoutRefreshToken(), ), }, @@ -1168,7 +1230,7 @@ func TestRefreshGrant(t *testing.T) { { name: "when the refresh request adds a new scope to the list of requested scopes then it is ignored", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": goodUpstreamSubject, @@ -1176,9 +1238,9 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { @@ -1193,7 +1255,7 @@ func TestRefreshGrant(t *testing.T) { { name: "when the refresh request removes a scope which was originally granted from the list of requested scopes then it is granted anyway", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": goodUpstreamSubject, @@ -1201,14 +1263,14 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access pinniped:request-audience") }, want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusOK, wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), + wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1221,7 +1283,7 @@ func TestRefreshGrant(t *testing.T) { wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), }, }, @@ -1229,7 +1291,7 @@ func TestRefreshGrant(t *testing.T) { { name: "when the refresh request does not include a scope param then it gets all the same scopes as the original authorization request", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": goodUpstreamSubject, @@ -1237,9 +1299,9 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { @@ -1255,14 +1317,14 @@ func TestRefreshGrant(t *testing.T) { name: "when a bad refresh token is sent in the refresh request", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusOK, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, - wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), + wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1279,14 +1341,14 @@ func TestRefreshGrant(t *testing.T) { name: "when the access token is sent as if it were a refresh token", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusOK, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, - wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), + wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1303,14 +1365,14 @@ func TestRefreshGrant(t *testing.T) { name: "when the wrong client ID is included in the refresh request", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, want: tokenEndpointResponseExpectedValues{ wantStatus: http.StatusOK, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, - wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), + wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(), }, }, refreshRequest: refreshRequestInputs{ @@ -1474,7 +1536,7 @@ func TestRefreshGrant(t *testing.T) { }, }, { - name: "when there is no OIDC refresh token in custom session data found in the session storage during the refresh request", + name: "when there is no OIDC refresh token nor access token in custom session data found in the session storage during the refresh request", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), authcodeExchange: authcodeExchangeInputs{ customSessionData: &psession.CustomSessionData{ @@ -1482,7 +1544,8 @@ func TestRefreshGrant(t *testing.T) { ProviderUID: oidcUpstreamResourceUID, ProviderType: oidcUpstreamType, OIDC: &psession.OIDCSessionData{ - UpstreamRefreshToken: "", // this should not happen in practice + UpstreamRefreshToken: "", // this should not happen in practice. we should always have exactly one of these. + UpstreamAccessToken: "", }, }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, @@ -1493,6 +1556,7 @@ func TestRefreshGrant(t *testing.T) { ProviderType: oidcUpstreamType, OIDC: &psession.OIDCSessionData{ UpstreamRefreshToken: "", // this should not happen in practice + UpstreamAccessToken: "", }, }, ), @@ -1573,9 +1637,9 @@ func TestRefreshGrant(t *testing.T) { idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ @@ -1595,17 +1659,17 @@ func TestRefreshGrant(t *testing.T) { idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go - WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). + WithValidateTokenAndMergeWithUserInfoError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { @@ -1621,7 +1685,7 @@ func TestRefreshGrant(t *testing.T) { idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go - WithValidatedTokens(&oidctypes.Token{ + WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "sub": "something-different", @@ -1630,14 +1694,14 @@ func TestRefreshGrant(t *testing.T) { }). Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { @@ -1651,7 +1715,7 @@ func TestRefreshGrant(t *testing.T) { { name: "refresh grant with claims but not the subject claim", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "some-claim": "some-value", @@ -1659,14 +1723,14 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { @@ -1680,7 +1744,7 @@ func TestRefreshGrant(t *testing.T) { { name: "refresh grant with changed username claim", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "some-claim": "some-value", @@ -1690,14 +1754,14 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { @@ -1711,7 +1775,7 @@ func TestRefreshGrant(t *testing.T) { { name: "refresh grant with changed issuer claim", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( - upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ + upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{ IDToken: &oidctypes.IDToken{ Claims: map[string]interface{}{ "some-claim": "some-value", @@ -1721,14 +1785,14 @@ func TestRefreshGrant(t *testing.T) { }, }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), authcodeExchange: authcodeExchangeInputs{ - customSessionData: initialUpstreamOIDCCustomSessionData(), + customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(), modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, - want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()), }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), - wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index fa290017..bed4685b 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -75,12 +75,14 @@ type RevokeRefreshTokenArgs struct { RefreshToken string } -// ValidateTokenArgs is used to spy on calls to -// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc(). -type ValidateTokenArgs struct { +// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc(). +type ValidateTokenAndMergeWithUserInfoArgs struct { Ctx context.Context Tok *oauth2.Token ExpectedIDTokenNonce nonce.Nonce + RequireIDToken bool + RequireUserInfo bool } type ValidateRefreshArgs struct { @@ -175,7 +177,7 @@ type TestUpstreamOIDCIdentityProvider struct { RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error - ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) + ValidateTokenAndMergeWithUserInfoFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs @@ -185,8 +187,8 @@ type TestUpstreamOIDCIdentityProvider struct { performRefreshArgs []*PerformRefreshArgs revokeRefreshTokenCallCount int revokeRefreshTokenArgs []*RevokeRefreshTokenArgs - validateTokenCallCount int - validateTokenArgs []*ValidateTokenArgs + validateTokenAndMergeWithUserInfoCallCount int + validateTokenAndMergeWithUserInfoArgs []*ValidateTokenAndMergeWithUserInfoArgs } var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} @@ -323,28 +325,30 @@ func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *Rev return u.revokeRefreshTokenArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) { - if u.validateTokenArgs == nil { - u.validateTokenArgs = make([]*ValidateTokenArgs, 0) +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) { + if u.validateTokenAndMergeWithUserInfoArgs == nil { + u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0) } - u.validateTokenCallCount++ - u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{ + u.validateTokenAndMergeWithUserInfoCallCount++ + u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{ Ctx: ctx, Tok: tok, ExpectedIDTokenNonce: expectedIDTokenNonce, + RequireIDToken: requireIDToken, + RequireUserInfo: requireUserInfo, }) - return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce) + return u.ValidateTokenAndMergeWithUserInfoFunc(ctx, tok, expectedIDTokenNonce) } -func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int { - return u.validateTokenCallCount +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoCallCount() int { + return u.validateTokenAndMergeWithUserInfoCallCount } -func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs { - if u.validateTokenArgs == nil { - u.validateTokenArgs = make([]*ValidateTokenArgs, 0) +func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoArgs(call int) *ValidateTokenAndMergeWithUserInfoArgs { + if u.validateTokenAndMergeWithUserInfoArgs == nil { + u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0) } - return u.validateTokenArgs[call] + return u.validateTokenAndMergeWithUserInfoArgs[call] } type UpstreamIDPListerBuilder struct { @@ -529,18 +533,18 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *te func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken( t *testing.T, expectedPerformedByUpstreamName string, - expectedArgs *ValidateTokenArgs, + expectedArgs *ValidateTokenAndMergeWithUserInfoArgs, ) { t.Helper() - var actualArgs *ValidateTokenArgs + var actualArgs *ValidateTokenAndMergeWithUserInfoArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount + callCountOnThisUpstream := upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name - actualArgs = upstreamOIDC.validateTokenArgs[0] + actualArgs = upstreamOIDC.validateTokenAndMergeWithUserInfoArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, @@ -556,7 +560,7 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, "expected exactly zero calls to ValidateTokenAndMergeWithUserInfo()", @@ -605,26 +609,26 @@ func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder { } type TestUpstreamOIDCIdentityProviderBuilder struct { - name string - resourceUID types.UID - clientID string - scopes []string - idToken map[string]interface{} - refreshToken *oidctypes.RefreshToken - accessToken *oidctypes.AccessToken - usernameClaim string - groupsClaim string - refreshedTokens *oauth2.Token - validatedTokens *oidctypes.Token - authorizationURL url.URL - hasUserInfoURL bool - additionalAuthcodeParams map[string]string - allowPasswordGrant bool - authcodeExchangeErr error - passwordGrantErr error - performRefreshErr error - revokeRefreshTokenErr error - validateTokenErr error + name string + resourceUID types.UID + clientID string + scopes []string + idToken map[string]interface{} + refreshToken *oidctypes.RefreshToken + accessToken *oidctypes.AccessToken + usernameClaim string + groupsClaim string + refreshedTokens *oauth2.Token + validatedAndMergedWithUserInfoTokens *oidctypes.Token + authorizationURL url.URL + hasUserInfoURL bool + additionalAuthcodeParams map[string]string + allowPasswordGrant bool + authcodeExchangeErr error + passwordGrantErr error + performRefreshErr error + revokeRefreshTokenErr error + validateTokenAndMergeWithUserInfoErr error } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { @@ -754,13 +758,13 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err er return u } -func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { - u.validatedTokens = tokens +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedAndMergedWithUserInfoTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { + u.validatedAndMergedWithUserInfoTokens = tokens return u } -func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { - u.validateTokenErr = err +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenAndMergeWithUserInfoError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.validateTokenAndMergeWithUserInfoErr = err return u } @@ -802,11 +806,11 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error { return u.revokeRefreshTokenErr }, - ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { - if u.validateTokenErr != nil { - return nil, u.validateTokenErr + ValidateTokenAndMergeWithUserInfoFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.validateTokenAndMergeWithUserInfoErr != nil { + return nil, u.validateTokenAndMergeWithUserInfoErr } - return u.validatedTokens, nil + return u.validatedAndMergedWithUserInfoTokens, nil }, } } diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 5a1006a7..94ce7502 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -126,7 +126,7 @@ func (p *ProviderConfig) PasswordCredentialsGrantAndValidateTokens(ctx context.C // There is no nonce to validate for a resource owner password credentials grant because it skips using // the authorize endpoint and goes straight to the token endpoint. const skipNonceValidation nonce.Nonce = "" - return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, skipNonceValidation, true) + return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, skipNonceValidation, true, false) } func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) { @@ -140,7 +140,7 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, return nil, err } - return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, expectedIDTokenNonce, true) + return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, expectedIDTokenNonce, true, false) } func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { @@ -252,7 +252,7 @@ func (p *ProviderConfig) tryRevokeRefreshToken( // ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response, // if the provider offers the userinfo endpoint. -func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) { +func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) { var validatedClaims = make(map[string]interface{}) var idTokenExpiry time.Time @@ -268,7 +268,7 @@ func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, if len(idTokenSubject) > 0 || !requireIDToken { // only fetch userinfo if the ID token has a subject or if we are ignoring the id token completely. // otherwise, defer to existing ID token validation - if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims, requireIDToken); err != nil { + if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims, requireIDToken, requireUserInfo); err != nil { return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) } } @@ -322,10 +322,10 @@ func (p *ProviderConfig) validateIDToken(ctx context.Context, tok *oauth2.Token, return idTokenExpiry, idTok, nil } -func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool) error { +func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool, requireUserInfo bool) error { idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string) - userInfo, err := p.maybeFetchUserInfo(ctx, tok) + userInfo, err := p.maybeFetchUserInfo(ctx, tok, requireUserInfo) if err != nil { return err } @@ -368,9 +368,13 @@ func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, t return nil } -func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) { - // implementing the user info endpoint is not required, skip this logic when it is absent +func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token, requireUserInfo bool) (*coreosoidc.UserInfo, error) { + // implementing the user info endpoint is not required by the OIDC spec, but we may require it in certain situations. if !p.HasUserInfoURL() { + if requireUserInfo { + // TODO should these all be http errors? + return nil, httperr.New(http.StatusInternalServerError, "userinfo endpoint not found, but is required") + } return nil, nil } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 0cd2a727..811b2300 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -622,6 +622,7 @@ func TestProviderConfig(t *testing.T) { tok *oauth2.Token nonce nonce.Nonce requireIDToken bool + requireUserInfo bool userInfo *oidc.UserInfo rawClaims []byte userInfoErr error @@ -707,6 +708,34 @@ func TestProviderConfig(t *testing.T) { }, }, }, + { + name: "userinfo is required, token with id, access and refresh tokens, valid nonce, and userinfo with a value that doesn't exist in the id token", + tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}), + nonce: "some-nonce", + requireIDToken: true, + requireUserInfo: true, + rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), + userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`), + wantMergedTokens: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ + Token: "test-access-token", + Type: "test-token-type", + Expiry: metav1.NewTime(expiryTime), + }, + RefreshToken: &oidctypes.RefreshToken{ + Token: "test-initial-refresh-token", + }, + IDToken: &oidctypes.IDToken{ + Token: goodIDToken, + Claims: map[string]interface{}{ + "iss": "some-issuer", + "nonce": "some-nonce", + "sub": "some-subject", + "name": "Pinny TheSeal", + }, + }, + }, + }, { name: "claims from userinfo override id token claims", tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJuYW1lIjoiSm9obiBEb2UiLCJpc3MiOiJzb21lLWlzc3VlciIsIm5vbmNlIjoic29tZS1ub25jZSJ9.sBWi3_4cfGwrmMFZWkCghw4uvCnHN35h9xNX1gkwOtj6Oz_yKqpj7wfO4AqeWsRyrDGnkmIZbVuhAAJqPSi4GlNzN4NU8zh53PGDUpFlpDI1dvqDjIRb9iIEJpRIj34--Sz41H0ooxviIzvUdZFvQlaSzLOqgjR3ddHe2urhbtUuz_DsabP84AWo2DSg0y3ull6DRvk_DvzC6HNN8JwVi08fFvvV9BVq8kjdVeob7gajJkuGSTjsxNZGs5rbBuxBx0MZTQ8boR1fDNdG70GoIb4SsCoBSs7pZxtmGZPHInteY1SilHDDDmpQuE-LvSmvvPN_Cyk1d3eS-IR7hBbCAA"}), @@ -762,11 +791,12 @@ func TestProviderConfig(t *testing.T) { }, }, { - name: "token with id, access and refresh tokens and valid nonce, but no userinfo endpoint from discovery", - tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}), - nonce: "some-nonce", - requireIDToken: true, - rawClaims: []byte(`{"not_the_userinfo_endpoint": "some-other-endpoint"}`), + name: "token with id, access and refresh tokens and valid nonce, but no userinfo endpoint from discovery and it's not required", + tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}), + nonce: "some-nonce", + requireIDToken: true, + requireUserInfo: false, + rawClaims: []byte(`{"not_the_userinfo_endpoint": "some-other-endpoint"}`), wantMergedTokens: &oidctypes.Token{ AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", @@ -876,6 +906,23 @@ func TestProviderConfig(t *testing.T) { userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`), wantErr: "received response missing ID token", }, + { + name: "expected to have userinfo, but doesn't", + tok: testTokenWithoutIDToken, + nonce: "some-other-nonce", + requireUserInfo: true, + rawClaims: []byte(`{}`), + wantErr: "could not fetch user info claims: userinfo endpoint not found, but is required", + }, + { + name: "expected to have id token and userinfo, but doesn't have either", + tok: testTokenWithoutIDToken, + nonce: "some-other-nonce", + requireUserInfo: true, + requireIDToken: true, + rawClaims: []byte(`{}`), + wantErr: "received response missing ID token", + }, { name: "mismatched access token hash", tok: testTokenWithoutIDToken, @@ -936,7 +983,7 @@ func TestProviderConfig(t *testing.T) { userInfoErr: tt.userInfoErr, }, } - gotTok, err := p.ValidateTokenAndMergeWithUserInfo(context.Background(), tt.tok, tt.nonce, tt.requireIDToken) + gotTok, err := p.ValidateTokenAndMergeWithUserInfo(context.Background(), tt.tok, tt.nonce, tt.requireIDToken, tt.requireUserInfo) if tt.wantErr != "" { require.Error(t, err) require.Equal(t, tt.wantErr, err.Error()) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 325ec4a4..d9364689 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -1,4 +1,4 @@ -// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 // Package oidcclient implements a CLI OIDC login flow. @@ -822,7 +822,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true) + return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true, false) } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index f29250e8..8ee920d7 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -406,7 +406,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). + ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). Return(&testToken, nil) mock.EXPECT(). PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). @@ -453,7 +453,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). + ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). Return(nil, fmt.Errorf("some validation error")) mock.EXPECT(). PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token"). @@ -1648,7 +1648,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). + ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). Return(&testToken, nil) mock.EXPECT(). PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).