Perform access token based refresh by fetching the userinfo

This commit is contained in:
Margo Crawford 2022-01-12 18:05:10 -08:00
parent 651d392b00
commit 62be761ef1
9 changed files with 286 additions and 153 deletions

View File

@ -244,16 +244,16 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0
} }
// ValidateTokenAndMergeWithUserInfo mocks base method. // ValidateTokenAndMergeWithUserInfo mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) ValidateTokenAndMergeWithUserInfo(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3 bool) (*oidctypes.Token, error) { func (m *MockUpstreamOIDCIdentityProviderI) ValidateTokenAndMergeWithUserInfo(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3, arg4 bool) (*oidctypes.Token, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateTokenAndMergeWithUserInfo", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "ValidateTokenAndMergeWithUserInfo", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(*oidctypes.Token) ret0, _ := ret[0].(*oidctypes.Token)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// ValidateTokenAndMergeWithUserInfo indicates an expected call of ValidateTokenAndMergeWithUserInfo. // ValidateTokenAndMergeWithUserInfo indicates an expected call of ValidateTokenAndMergeWithUserInfo.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateTokenAndMergeWithUserInfo(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateTokenAndMergeWithUserInfo(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateTokenAndMergeWithUserInfo", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateTokenAndMergeWithUserInfo), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateTokenAndMergeWithUserInfo", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateTokenAndMergeWithUserInfo), arg0, arg1, arg2, arg3, arg4)
} }

View File

@ -77,7 +77,7 @@ type UpstreamOIDCIdentityProviderI interface {
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response // ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
// tokens, or an error. // tokens, or an error.
ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error)
} }
type UpstreamLDAPIdentityProviderI interface { type UpstreamLDAPIdentityProviderI interface {

View File

@ -11,6 +11,7 @@ import (
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/x/errorsx" "github.com/ory/x/errorsx"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
@ -101,7 +102,13 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
s := session.Custom s := session.Custom
if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" { if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
refreshTokenStored := s.OIDC.UpstreamRefreshToken != ""
refreshTokenOrAccessTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
if !refreshTokenOrAccessTokenStored {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
} }
@ -113,21 +120,28 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
plog.Debug("attempting upstream refresh request", plog.Debug("attempting upstream refresh request",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) var tokens *oauth2.Token
if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint( return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.", "Upstream refresh failed.",
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) ).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
} else {
tokens = &oauth2.Token{
AccessToken: s.OIDC.UpstreamAccessToken,
}
}
// Upstream refresh may or may not return a new ID token. From the spec: // Upstream refresh may or may not return a new ID token. From the spec:
// "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token." // "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token."
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
_, hasIDTok := refreshedTokens.Extra("id_token").(string) _, hasIDTok := tokens.Extra("id_token").(string)
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at
// least some providers do not include one, so we skip the nonce validation here (but not other validations). // least some providers do not include one, so we skip the nonce validation here (but not other validations).
validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, refreshedTokens, "", hasIDTok) validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) "Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
@ -166,10 +180,10 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
// overwriting the old one. // overwriting the old one.
if refreshedTokens.RefreshToken != "" { if tokens.RefreshToken != "" {
plog.Debug("upstream refresh request did not return a new refresh token", plog.Debug("upstream refresh request did not return a new refresh token",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken s.OIDC.UpstreamRefreshToken = tokens.RefreshToken
} }
return nil return nil

View File

@ -225,7 +225,7 @@ type expectedUpstreamRefresh struct {
type expectedUpstreamValidateTokens struct { type expectedUpstreamValidateTokens struct {
performedByUpstreamName string performedByUpstreamName string
args *oidctestutil.ValidateTokenArgs args *oidctestutil.ValidateTokenAndMergeWithUserInfoArgs
} }
type tokenEndpointResponseExpectedValues struct { type tokenEndpointResponseExpectedValues struct {
@ -881,6 +881,7 @@ func TestRefreshGrant(t *testing.T) {
oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token" oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token"
oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token" oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token"
oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token" oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token"
oidcUpstreamAccessToken = "fake-upstream-access-token" //nolint:gosec
ldapUpstreamName = "some-ldap-idp" ldapUpstreamName = "some-ldap-idp"
ldapUpstreamResourceUID = "ldap-resource-uid" ldapUpstreamResourceUID = "ldap-resource-uid"
@ -904,7 +905,7 @@ func TestRefreshGrant(t *testing.T) {
WithResourceUID(oidcUpstreamResourceUID) WithResourceUID(oidcUpstreamResourceUID)
} }
initialUpstreamOIDCCustomSessionData := func() *psession.CustomSessionData { initialUpstreamOIDCRefreshTokenCustomSessionData := func() *psession.CustomSessionData {
return &psession.CustomSessionData{ return &psession.CustomSessionData{
ProviderName: oidcUpstreamName, ProviderName: oidcUpstreamName,
ProviderUID: oidcUpstreamResourceUID, ProviderUID: oidcUpstreamResourceUID,
@ -917,8 +918,21 @@ func TestRefreshGrant(t *testing.T) {
} }
} }
initialUpstreamOIDCAccessTokenCustomSessionData := func() *psession.CustomSessionData {
return &psession.CustomSessionData{
ProviderName: oidcUpstreamName,
ProviderUID: oidcUpstreamResourceUID,
ProviderType: oidcUpstreamType,
OIDC: &psession.OIDCSessionData{
UpstreamAccessToken: oidcUpstreamAccessToken,
UpstreamSubject: goodUpstreamSubject,
UpstreamIssuer: goodIssuer,
},
}
}
upstreamOIDCCustomSessionDataWithNewRefreshToken := func(newRefreshToken string) *psession.CustomSessionData { upstreamOIDCCustomSessionDataWithNewRefreshToken := func(newRefreshToken string) *psession.CustomSessionData {
sessionData := initialUpstreamOIDCCustomSessionData() sessionData := initialUpstreamOIDCRefreshTokenCustomSessionData()
sessionData.OIDC.UpstreamRefreshToken = newRefreshToken sessionData.OIDC.UpstreamRefreshToken = newRefreshToken
return sessionData return sessionData
} }
@ -957,13 +971,15 @@ func TestRefreshGrant(t *testing.T) {
} }
} }
happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token) *expectedUpstreamValidateTokens { happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token, requireIDToken bool) *expectedUpstreamValidateTokens {
return &expectedUpstreamValidateTokens{ return &expectedUpstreamValidateTokens{
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.ValidateTokenArgs{ args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: expectedTokens, Tok: expectedTokens,
ExpectedIDTokenNonce: "", // always expect empty string ExpectedIDTokenNonce: "", // always expect empty string
RequireUserInfo: false,
RequireIDToken: requireIDToken,
}, },
} }
} }
@ -986,7 +1002,7 @@ func TestRefreshGrant(t *testing.T) {
// Should always try to perform an upstream refresh. // Should always try to perform an upstream refresh.
want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall() want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall()
if expectToValidateToken != nil { if expectToValidateToken != nil {
want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken, true)
} }
return want return want
} }
@ -1049,7 +1065,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "happy path refresh grant with openid scope granted (id token returned)", name: "happy path refresh grant with openid scope granted (id token returned)",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": goodUpstreamSubject, "sub": goodUpstreamSubject,
@ -1057,9 +1073,9 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( want: happyRefreshTokenResponseForOpenIDAndOfflineAccess(
@ -1071,7 +1087,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "refresh grant with unchanged username claim", name: "refresh grant with unchanged username claim",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"some-claim": "some-value", "some-claim": "some-value",
@ -1081,9 +1097,9 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( want: happyRefreshTokenResponseForOpenIDAndOfflineAccess(
@ -1093,22 +1109,63 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
{ {
name: "happy path refresh grant without openid scope granted (no id token returned)", name: "refresh grant when the customsessiondata has a stored access token and no stored refresh token",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").
WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{}, Claims: map[string]interface{}{
"some-claim": "some-value",
"sub": goodUpstreamSubject,
"username-claim": goodUsername,
},
},
AccessToken: &oidctypes.AccessToken{
Token: oidcUpstreamAccessToken,
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCAccessTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCAccessTokenCustomSessionData()),
}, // do not want upstreamRefreshRequest???
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "id_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"openid", "offline_access"},
wantGrantedScopes: []string{"openid", "offline_access"},
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
oidcUpstreamName,
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
ExpectedIDTokenNonce: "", // always expect empty string
RequireIDToken: false,
RequireUserInfo: true,
},
},
wantCustomSessionDataStored: initialUpstreamOIDCAccessTokenCustomSessionData(), // doesn't change when we refresh
},
},
},
{
name: "happy path refresh grant without openid scope granted (no id token returned)",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{},
},
}).WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") },
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"},
wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(),
}, },
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
@ -1118,7 +1175,7 @@ func TestRefreshGrant(t *testing.T) {
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"},
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), false),
wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken),
}, },
}, },
@ -1126,27 +1183,32 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "happy path refresh grant when the upstream refresh does not return a new ID token", name: "happy path refresh grant when the upstream refresh does not return a new ID token",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{}, Claims: map[string]interface{}{},
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( want: tokenEndpointResponseExpectedValues{
upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), wantStatus: http.StatusOK,
refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), // expect ValidateTokenAndMergeWithUserInfo is called wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
), wantRequestedScopes: []string{"openid", "offline_access"},
wantGrantedScopes: []string{"openid", "offline_access"},
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), false),
wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken),
},
}, },
}, },
{ {
name: "happy path refresh grant when the upstream refresh does not return a new refresh token", name: "happy path refresh grant when the upstream refresh does not return a new refresh token",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": goodUpstreamSubject, "sub": goodUpstreamSubject,
@ -1154,13 +1216,13 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDTokenWithoutRefreshToken()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDTokenWithoutRefreshToken()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( want: happyRefreshTokenResponseForOpenIDAndOfflineAccess(
initialUpstreamOIDCCustomSessionData(), // still has the initial refresh token stored initialUpstreamOIDCRefreshTokenCustomSessionData(), // still has the initial refresh token stored
refreshedUpstreamTokensWithIDTokenWithoutRefreshToken(), refreshedUpstreamTokensWithIDTokenWithoutRefreshToken(),
), ),
}, },
@ -1168,7 +1230,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "when the refresh request adds a new scope to the list of requested scopes then it is ignored", name: "when the refresh request adds a new scope to the list of requested scopes then it is ignored",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": goodUpstreamSubject, "sub": goodUpstreamSubject,
@ -1176,9 +1238,9 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) {
@ -1193,7 +1255,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "when the refresh request removes a scope which was originally granted from the list of requested scopes then it is granted anyway", name: "when the refresh request removes a scope which was originally granted from the list of requested scopes then it is granted anyway",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": goodUpstreamSubject, "sub": goodUpstreamSubject,
@ -1201,14 +1263,14 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access pinniped:request-audience") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access pinniped:request-audience") },
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"},
wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"},
wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(),
}, },
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
@ -1221,7 +1283,7 @@ func TestRefreshGrant(t *testing.T) {
wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"},
wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"},
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken),
}, },
}, },
@ -1229,7 +1291,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "when the refresh request does not include a scope param then it gets all the same scopes as the original authorization request", name: "when the refresh request does not include a scope param then it gets all the same scopes as the original authorization request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": goodUpstreamSubject, "sub": goodUpstreamSubject,
@ -1237,9 +1299,9 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) { modifyTokenRequest: func(r *http.Request, refreshToken string, accessToken string) {
@ -1255,14 +1317,14 @@ func TestRefreshGrant(t *testing.T) {
name: "when a bad refresh token is sent in the refresh request", name: "when a bad refresh token is sent in the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") },
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"},
wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(),
}, },
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
@ -1279,14 +1341,14 @@ func TestRefreshGrant(t *testing.T) {
name: "when the access token is sent as if it were a refresh token", name: "when the access token is sent as if it were a refresh token",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") },
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"},
wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(),
}, },
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
@ -1303,14 +1365,14 @@ func TestRefreshGrant(t *testing.T) {
name: "when the wrong client ID is included in the refresh request", name: "when the wrong client ID is included in the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "offline_access") },
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"},
wantCustomSessionDataStored: initialUpstreamOIDCCustomSessionData(), wantCustomSessionDataStored: initialUpstreamOIDCRefreshTokenCustomSessionData(),
}, },
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
@ -1474,7 +1536,7 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
{ {
name: "when there is no OIDC refresh token in custom session data found in the session storage during the refresh request", name: "when there is no OIDC refresh token nor access token in custom session data found in the session storage during the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()), idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: &psession.CustomSessionData{ customSessionData: &psession.CustomSessionData{
@ -1482,7 +1544,8 @@ func TestRefreshGrant(t *testing.T) {
ProviderUID: oidcUpstreamResourceUID, ProviderUID: oidcUpstreamResourceUID,
ProviderType: oidcUpstreamType, ProviderType: oidcUpstreamType,
OIDC: &psession.OIDCSessionData{ OIDC: &psession.OIDCSessionData{
UpstreamRefreshToken: "", // this should not happen in practice UpstreamRefreshToken: "", // this should not happen in practice. we should always have exactly one of these.
UpstreamAccessToken: "",
}, },
}, },
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
@ -1493,6 +1556,7 @@ func TestRefreshGrant(t *testing.T) {
ProviderType: oidcUpstreamType, ProviderType: oidcUpstreamType,
OIDC: &psession.OIDCSessionData{ OIDC: &psession.OIDCSessionData{
UpstreamRefreshToken: "", // this should not happen in practice UpstreamRefreshToken: "", // this should not happen in practice
UpstreamAccessToken: "",
}, },
}, },
), ),
@ -1573,9 +1637,9 @@ func TestRefreshGrant(t *testing.T) {
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
@ -1595,17 +1659,17 @@ func TestRefreshGrant(t *testing.T) {
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).
// This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go
WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). WithValidateTokenAndMergeWithUserInfoError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))).
Build()), Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1621,7 +1685,7 @@ func TestRefreshGrant(t *testing.T) {
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).
// This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go
WithValidatedTokens(&oidctypes.Token{ WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"sub": "something-different", "sub": "something-different",
@ -1630,14 +1694,14 @@ func TestRefreshGrant(t *testing.T) {
}). }).
Build()), Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1651,7 +1715,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "refresh grant with claims but not the subject claim", name: "refresh grant with claims but not the subject claim",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"some-claim": "some-value", "some-claim": "some-value",
@ -1659,14 +1723,14 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1680,7 +1744,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "refresh grant with changed username claim", name: "refresh grant with changed username claim",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"some-claim": "some-value", "some-claim": "some-value",
@ -1690,14 +1754,14 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1711,7 +1775,7 @@ func TestRefreshGrant(t *testing.T) {
{ {
name: "refresh grant with changed issuer claim", name: "refresh grant with changed issuer claim",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedAndMergedWithUserInfoTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"some-claim": "some-value", "some-claim": "some-value",
@ -1721,14 +1785,14 @@ func TestRefreshGrant(t *testing.T) {
}, },
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(), customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens(), true),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {

View File

@ -75,12 +75,14 @@ type RevokeRefreshTokenArgs struct {
RefreshToken string RefreshToken string
} }
// ValidateTokenArgs is used to spy on calls to // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc(). // TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
type ValidateTokenArgs struct { type ValidateTokenAndMergeWithUserInfoArgs struct {
Ctx context.Context Ctx context.Context
Tok *oauth2.Token Tok *oauth2.Token
ExpectedIDTokenNonce nonce.Nonce ExpectedIDTokenNonce nonce.Nonce
RequireIDToken bool
RequireUserInfo bool
} }
type ValidateRefreshArgs struct { type ValidateRefreshArgs struct {
@ -175,7 +177,7 @@ type TestUpstreamOIDCIdentityProvider struct {
RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) ValidateTokenAndMergeWithUserInfoFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensCallCount int
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
@ -185,8 +187,8 @@ type TestUpstreamOIDCIdentityProvider struct {
performRefreshArgs []*PerformRefreshArgs performRefreshArgs []*PerformRefreshArgs
revokeRefreshTokenCallCount int revokeRefreshTokenCallCount int
revokeRefreshTokenArgs []*RevokeRefreshTokenArgs revokeRefreshTokenArgs []*RevokeRefreshTokenArgs
validateTokenCallCount int validateTokenAndMergeWithUserInfoCallCount int
validateTokenArgs []*ValidateTokenArgs validateTokenAndMergeWithUserInfoArgs []*ValidateTokenAndMergeWithUserInfoArgs
} }
var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{}
@ -323,28 +325,30 @@ func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *Rev
return u.revokeRefreshTokenArgs[call] return u.revokeRefreshTokenArgs[call]
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) { func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) {
if u.validateTokenArgs == nil { if u.validateTokenAndMergeWithUserInfoArgs == nil {
u.validateTokenArgs = make([]*ValidateTokenArgs, 0) u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0)
} }
u.validateTokenCallCount++ u.validateTokenAndMergeWithUserInfoCallCount++
u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{ u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
Ctx: ctx, Ctx: ctx,
Tok: tok, Tok: tok,
ExpectedIDTokenNonce: expectedIDTokenNonce, ExpectedIDTokenNonce: expectedIDTokenNonce,
RequireIDToken: requireIDToken,
RequireUserInfo: requireUserInfo,
}) })
return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce) return u.ValidateTokenAndMergeWithUserInfoFunc(ctx, tok, expectedIDTokenNonce)
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int { func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoCallCount() int {
return u.validateTokenCallCount return u.validateTokenAndMergeWithUserInfoCallCount
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs { func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoArgs(call int) *ValidateTokenAndMergeWithUserInfoArgs {
if u.validateTokenArgs == nil { if u.validateTokenAndMergeWithUserInfoArgs == nil {
u.validateTokenArgs = make([]*ValidateTokenArgs, 0) u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0)
} }
return u.validateTokenArgs[call] return u.validateTokenAndMergeWithUserInfoArgs[call]
} }
type UpstreamIDPListerBuilder struct { type UpstreamIDPListerBuilder struct {
@ -529,18 +533,18 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *te
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken( func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken(
t *testing.T, t *testing.T,
expectedPerformedByUpstreamName string, expectedPerformedByUpstreamName string,
expectedArgs *ValidateTokenArgs, expectedArgs *ValidateTokenAndMergeWithUserInfoArgs,
) { ) {
t.Helper() t.Helper()
var actualArgs *ValidateTokenArgs var actualArgs *ValidateTokenAndMergeWithUserInfoArgs
var actualNameOfUpstreamWhichMadeCall string var actualNameOfUpstreamWhichMadeCall string
actualCallCountAcrossAllOIDCUpstreams := 0 actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount callCountOnThisUpstream := upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.validateTokenArgs[0] actualArgs = upstreamOIDC.validateTokenAndMergeWithUserInfoArgs[0]
} }
} }
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
@ -556,7 +560,7 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes
t.Helper() t.Helper()
actualCallCountAcrossAllOIDCUpstreams := 0 actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount
} }
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to ValidateTokenAndMergeWithUserInfo()", "expected exactly zero calls to ValidateTokenAndMergeWithUserInfo()",
@ -615,7 +619,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
usernameClaim string usernameClaim string
groupsClaim string groupsClaim string
refreshedTokens *oauth2.Token refreshedTokens *oauth2.Token
validatedTokens *oidctypes.Token validatedAndMergedWithUserInfoTokens *oidctypes.Token
authorizationURL url.URL authorizationURL url.URL
hasUserInfoURL bool hasUserInfoURL bool
additionalAuthcodeParams map[string]string additionalAuthcodeParams map[string]string
@ -624,7 +628,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
passwordGrantErr error passwordGrantErr error
performRefreshErr error performRefreshErr error
revokeRefreshTokenErr error revokeRefreshTokenErr error
validateTokenErr error validateTokenAndMergeWithUserInfoErr error
} }
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder {
@ -754,13 +758,13 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err er
return u return u
} }
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedAndMergedWithUserInfoTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder {
u.validatedTokens = tokens u.validatedAndMergedWithUserInfoTokens = tokens
return u return u
} }
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenAndMergeWithUserInfoError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.validateTokenErr = err u.validateTokenAndMergeWithUserInfoErr = err
return u return u
} }
@ -802,11 +806,11 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error { RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error {
return u.revokeRefreshTokenErr return u.revokeRefreshTokenErr
}, },
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { ValidateTokenAndMergeWithUserInfoFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
if u.validateTokenErr != nil { if u.validateTokenAndMergeWithUserInfoErr != nil {
return nil, u.validateTokenErr return nil, u.validateTokenAndMergeWithUserInfoErr
} }
return u.validatedTokens, nil return u.validatedAndMergedWithUserInfoTokens, nil
}, },
} }
} }

View File

@ -126,7 +126,7 @@ func (p *ProviderConfig) PasswordCredentialsGrantAndValidateTokens(ctx context.C
// There is no nonce to validate for a resource owner password credentials grant because it skips using // There is no nonce to validate for a resource owner password credentials grant because it skips using
// the authorize endpoint and goes straight to the token endpoint. // the authorize endpoint and goes straight to the token endpoint.
const skipNonceValidation nonce.Nonce = "" const skipNonceValidation nonce.Nonce = ""
return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, skipNonceValidation, true) return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, skipNonceValidation, true, false)
} }
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) { func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) {
@ -140,7 +140,7 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context,
return nil, err return nil, err
} }
return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, expectedIDTokenNonce, true) return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, expectedIDTokenNonce, true, false)
} }
func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
@ -252,7 +252,7 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response, // ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response,
// if the provider offers the userinfo endpoint. // if the provider offers the userinfo endpoint.
func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) { func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) {
var validatedClaims = make(map[string]interface{}) var validatedClaims = make(map[string]interface{})
var idTokenExpiry time.Time var idTokenExpiry time.Time
@ -268,7 +268,7 @@ func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context,
if len(idTokenSubject) > 0 || !requireIDToken { if len(idTokenSubject) > 0 || !requireIDToken {
// only fetch userinfo if the ID token has a subject or if we are ignoring the id token completely. // only fetch userinfo if the ID token has a subject or if we are ignoring the id token completely.
// otherwise, defer to existing ID token validation // otherwise, defer to existing ID token validation
if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims, requireIDToken); err != nil { if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims, requireIDToken, requireUserInfo); err != nil {
return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err)
} }
} }
@ -322,10 +322,10 @@ func (p *ProviderConfig) validateIDToken(ctx context.Context, tok *oauth2.Token,
return idTokenExpiry, idTok, nil return idTokenExpiry, idTok, nil
} }
func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool) error { func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}, requireIDToken bool, requireUserInfo bool) error {
idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string) idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string)
userInfo, err := p.maybeFetchUserInfo(ctx, tok) userInfo, err := p.maybeFetchUserInfo(ctx, tok, requireUserInfo)
if err != nil { if err != nil {
return err return err
} }
@ -368,9 +368,13 @@ func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, t
return nil return nil
} }
func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token) (*coreosoidc.UserInfo, error) { func (p *ProviderConfig) maybeFetchUserInfo(ctx context.Context, tok *oauth2.Token, requireUserInfo bool) (*coreosoidc.UserInfo, error) {
// implementing the user info endpoint is not required, skip this logic when it is absent // implementing the user info endpoint is not required by the OIDC spec, but we may require it in certain situations.
if !p.HasUserInfoURL() { if !p.HasUserInfoURL() {
if requireUserInfo {
// TODO should these all be http errors?
return nil, httperr.New(http.StatusInternalServerError, "userinfo endpoint not found, but is required")
}
return nil, nil return nil, nil
} }

View File

@ -622,6 +622,7 @@ func TestProviderConfig(t *testing.T) {
tok *oauth2.Token tok *oauth2.Token
nonce nonce.Nonce nonce nonce.Nonce
requireIDToken bool requireIDToken bool
requireUserInfo bool
userInfo *oidc.UserInfo userInfo *oidc.UserInfo
rawClaims []byte rawClaims []byte
userInfoErr error userInfoErr error
@ -707,6 +708,34 @@ func TestProviderConfig(t *testing.T) {
}, },
}, },
}, },
{
name: "userinfo is required, token with id, access and refresh tokens, valid nonce, and userinfo with a value that doesn't exist in the id token",
tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}),
nonce: "some-nonce",
requireIDToken: true,
requireUserInfo: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: "test-access-token",
Type: "test-token-type",
Expiry: metav1.NewTime(expiryTime),
},
RefreshToken: &oidctypes.RefreshToken{
Token: "test-initial-refresh-token",
},
IDToken: &oidctypes.IDToken{
Token: goodIDToken,
Claims: map[string]interface{}{
"iss": "some-issuer",
"nonce": "some-nonce",
"sub": "some-subject",
"name": "Pinny TheSeal",
},
},
},
},
{ {
name: "claims from userinfo override id token claims", name: "claims from userinfo override id token claims",
tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJuYW1lIjoiSm9obiBEb2UiLCJpc3MiOiJzb21lLWlzc3VlciIsIm5vbmNlIjoic29tZS1ub25jZSJ9.sBWi3_4cfGwrmMFZWkCghw4uvCnHN35h9xNX1gkwOtj6Oz_yKqpj7wfO4AqeWsRyrDGnkmIZbVuhAAJqPSi4GlNzN4NU8zh53PGDUpFlpDI1dvqDjIRb9iIEJpRIj34--Sz41H0ooxviIzvUdZFvQlaSzLOqgjR3ddHe2urhbtUuz_DsabP84AWo2DSg0y3ull6DRvk_DvzC6HNN8JwVi08fFvvV9BVq8kjdVeob7gajJkuGSTjsxNZGs5rbBuxBx0MZTQ8boR1fDNdG70GoIb4SsCoBSs7pZxtmGZPHInteY1SilHDDDmpQuE-LvSmvvPN_Cyk1d3eS-IR7hBbCAA"}), tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzb21lLXN1YmplY3QiLCJuYW1lIjoiSm9obiBEb2UiLCJpc3MiOiJzb21lLWlzc3VlciIsIm5vbmNlIjoic29tZS1ub25jZSJ9.sBWi3_4cfGwrmMFZWkCghw4uvCnHN35h9xNX1gkwOtj6Oz_yKqpj7wfO4AqeWsRyrDGnkmIZbVuhAAJqPSi4GlNzN4NU8zh53PGDUpFlpDI1dvqDjIRb9iIEJpRIj34--Sz41H0ooxviIzvUdZFvQlaSzLOqgjR3ddHe2urhbtUuz_DsabP84AWo2DSg0y3ull6DRvk_DvzC6HNN8JwVi08fFvvV9BVq8kjdVeob7gajJkuGSTjsxNZGs5rbBuxBx0MZTQ8boR1fDNdG70GoIb4SsCoBSs7pZxtmGZPHInteY1SilHDDDmpQuE-LvSmvvPN_Cyk1d3eS-IR7hBbCAA"}),
@ -762,10 +791,11 @@ func TestProviderConfig(t *testing.T) {
}, },
}, },
{ {
name: "token with id, access and refresh tokens and valid nonce, but no userinfo endpoint from discovery", name: "token with id, access and refresh tokens and valid nonce, but no userinfo endpoint from discovery and it's not required",
tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}), tok: testTokenWithoutIDToken.WithExtra(map[string]interface{}{"id_token": goodIDToken}),
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
requireUserInfo: false,
rawClaims: []byte(`{"not_the_userinfo_endpoint": "some-other-endpoint"}`), rawClaims: []byte(`{"not_the_userinfo_endpoint": "some-other-endpoint"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
@ -876,6 +906,23 @@ func TestProviderConfig(t *testing.T) {
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantErr: "received response missing ID token", wantErr: "received response missing ID token",
}, },
{
name: "expected to have userinfo, but doesn't",
tok: testTokenWithoutIDToken,
nonce: "some-other-nonce",
requireUserInfo: true,
rawClaims: []byte(`{}`),
wantErr: "could not fetch user info claims: userinfo endpoint not found, but is required",
},
{
name: "expected to have id token and userinfo, but doesn't have either",
tok: testTokenWithoutIDToken,
nonce: "some-other-nonce",
requireUserInfo: true,
requireIDToken: true,
rawClaims: []byte(`{}`),
wantErr: "received response missing ID token",
},
{ {
name: "mismatched access token hash", name: "mismatched access token hash",
tok: testTokenWithoutIDToken, tok: testTokenWithoutIDToken,
@ -936,7 +983,7 @@ func TestProviderConfig(t *testing.T) {
userInfoErr: tt.userInfoErr, userInfoErr: tt.userInfoErr,
}, },
} }
gotTok, err := p.ValidateTokenAndMergeWithUserInfo(context.Background(), tt.tok, tt.nonce, tt.requireIDToken) gotTok, err := p.ValidateTokenAndMergeWithUserInfo(context.Background(), tt.tok, tt.nonce, tt.requireIDToken, tt.requireUserInfo)
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(t, err) require.Error(t, err)
require.Equal(t, tt.wantErr, err.Error()) require.Equal(t, tt.wantErr, err.Error())

View File

@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package oidcclient implements a CLI OIDC login flow. // Package oidcclient implements a CLI OIDC login flow.
@ -822,7 +822,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true) return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true, false)
} }
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {

View File

@ -406,7 +406,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
Return(&testToken, nil) Return(&testToken, nil)
mock.EXPECT(). mock.EXPECT().
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).
@ -453,7 +453,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
Return(nil, fmt.Errorf("some validation error")) Return(nil, fmt.Errorf("some validation error"))
mock.EXPECT(). mock.EXPECT().
PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token"). PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token").
@ -1648,7 +1648,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
Return(&testToken, nil) Return(&testToken, nil)
mock.EXPECT(). mock.EXPECT().
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token). PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).