rename ValidateToken to ValidateTokenAndMergeWithUserInfo to better reflect what it's doing

Also changed a few comments and small things
This commit is contained in:
Margo Crawford 2021-12-16 12:53:49 -08:00
parent c9cf13a01f
commit f2d2144932
9 changed files with 62 additions and 78 deletions

View File

@ -14,11 +14,12 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
oauth2 "golang.org/x/oauth2"
types "k8s.io/apimachinery/pkg/types"
nonce "go.pinniped.dev/pkg/oidcclient/nonce" nonce "go.pinniped.dev/pkg/oidcclient/nonce"
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
pkce "go.pinniped.dev/pkg/oidcclient/pkce" pkce "go.pinniped.dev/pkg/oidcclient/pkce"
oauth2 "golang.org/x/oauth2"
types "k8s.io/apimachinery/pkg/types"
) )
// MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface. // MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface.
@ -230,9 +231,9 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0
} }
// ValidateToken mocks base method. // ValidateToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3 bool) (*oidctypes.Token, error) { func (m *MockUpstreamOIDCIdentityProviderI) ValidateTokenAndMergeWithUserInfo(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce, arg3 bool) (*oidctypes.Token, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "ValidateTokenAndMergeWithUserInfo", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*oidctypes.Token) ret0, _ := ret[0].(*oidctypes.Token)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
@ -241,5 +242,5 @@ func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context,
// ValidateToken indicates an expected call of ValidateToken. // ValidateToken indicates an expected call of ValidateToken.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateToken), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateTokenAndMergeWithUserInfo", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateTokenAndMergeWithUserInfo), arg0, arg1, arg2, arg3)
} }

View File

@ -97,7 +97,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
subject := DownstreamSubjectFromUpstreamOIDC(upstreamIssuer, upstreamSubject) subject := downstreamSubjectFromUpstreamOIDC(upstreamIssuer, upstreamSubject)
usernameClaimName := upstreamIDPConfig.GetUsernameClaim() usernameClaimName := upstreamIDPConfig.GetUsernameClaim()
if usernameClaimName == "" { if usernameClaimName == "" {
@ -176,7 +176,7 @@ func DownstreamLDAPSubject(uid string, ldapURL url.URL) string {
return ldapURL.String() return ldapURL.String()
} }
func DownstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSubject string) string { func downstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSubject string) string {
return fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, url.QueryEscape(upstreamSubject)) return fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, url.QueryEscape(upstreamSubject))
} }

View File

@ -74,7 +74,7 @@ type UpstreamOIDCIdentityProviderI interface {
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
// tokens, or an error. // tokens, or an error.
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error)
} }
type UpstreamLDAPIdentityProviderI interface { type UpstreamLDAPIdentityProviderI interface {

View File

@ -128,7 +128,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at
// least some providers do not include one, so we skip the nonce validation here (but not other validations). // least some providers do not include one, so we skip the nonce validation here (but not other validations).
validatedTokens, err := p.ValidateToken(ctx, refreshedTokens, "", hasIDTok) validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, refreshedTokens, "", hasIDTok)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) "Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))

View File

@ -982,7 +982,6 @@ func TestRefreshGrant(t *testing.T) {
want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored)
// Should always try to perform an upstream refresh. // Should always try to perform an upstream refresh.
want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall() want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall()
// Should only try to ValidateToken when there was an id token returned by the upstream refresh.
if expectToValidateToken != nil { if expectToValidateToken != nil {
want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken)
} }
@ -1137,7 +1136,7 @@ func TestRefreshGrant(t *testing.T) {
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: happyRefreshTokenResponseForOpenIDAndOfflineAccess( want: happyRefreshTokenResponseForOpenIDAndOfflineAccess(
upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken),
refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), // expect ValidateToken is called refreshedUpstreamTokensWithRefreshTokenWithoutIDToken(), // expect ValidateTokenAndMergeWithUserInfo is called
), ),
}, },
}, },
@ -1592,7 +1591,7 @@ func TestRefreshGrant(t *testing.T) {
name: "when the upstream refresh returns an invalid ID token during the refresh request", name: "when the upstream refresh returns an invalid ID token during the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).
// This is the current format of the errors returned by the production code version of ValidateToken, see ValidateToken in upstreamoidc.go // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go
WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))). WithValidateTokenError(httperr.Wrap(http.StatusBadRequest, "some validate error", errors.New("some validate cause"))).
Build()), Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
@ -1618,7 +1617,7 @@ func TestRefreshGrant(t *testing.T) {
name: "when the upstream refresh returns an ID token with a different subject than the original", name: "when the upstream refresh returns an ID token with a different subject than the original",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()). WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).
// This is the current format of the errors returned by the production code version of ValidateToken, see ValidateToken in upstreamoidc.go // This is the current format of the errors returned by the production code version of ValidateTokenAndMergeWithUserInfo, see ValidateTokenAndMergeWithUserInfo in upstreamoidc.go
WithValidatedTokens(&oidctypes.Token{ WithValidatedTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{ IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{ Claims: map[string]interface{}{

View File

@ -176,8 +176,6 @@ type TestUpstreamOIDCIdentityProvider struct {
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
ValidateRefreshFunc func(ctx context.Context, tok *oauth2.Token, storedAttributes provider.StoredRefreshAttributes) error
exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensCallCount int
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensCallCount int
@ -188,8 +186,6 @@ type TestUpstreamOIDCIdentityProvider struct {
revokeRefreshTokenArgs []*RevokeRefreshTokenArgs revokeRefreshTokenArgs []*RevokeRefreshTokenArgs
validateTokenCallCount int validateTokenCallCount int
validateTokenArgs []*ValidateTokenArgs validateTokenArgs []*ValidateTokenArgs
validateRefreshCallCount int
validateRefreshArgs []*ValidateRefreshArgs
} }
var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{}
@ -288,19 +284,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
return u.PerformRefreshFunc(ctx, refreshToken) return u.PerformRefreshFunc(ctx, refreshToken)
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateRefresh(ctx context.Context, tok *oauth2.Token, storedAttributes provider.StoredRefreshAttributes) error {
if u.validateRefreshArgs == nil {
u.validateRefreshArgs = make([]*ValidateRefreshArgs, 0)
}
u.validateRefreshCallCount++
u.validateRefreshArgs = append(u.validateRefreshArgs, &ValidateRefreshArgs{
Ctx: ctx,
Tok: tok,
StoredAttributes: storedAttributes,
})
return u.ValidateRefreshFunc(ctx, tok, storedAttributes)
}
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error { func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
if u.revokeRefreshTokenArgs == nil { if u.revokeRefreshTokenArgs == nil {
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0) u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0)
@ -335,7 +318,7 @@ func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *Rev
return u.revokeRefreshTokenArgs[call] return u.revokeRefreshTokenArgs[call]
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) { func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) {
if u.validateTokenArgs == nil { if u.validateTokenArgs == nil {
u.validateTokenArgs = make([]*ValidateTokenArgs, 0) u.validateTokenArgs = make([]*ValidateTokenArgs, 0)
} }
@ -556,10 +539,10 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken(
} }
} }
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
"should have been exactly one call to ValidateToken() by all OIDC upstreams", "should have been exactly one call to ValidateTokenAndMergeWithUserInfo() by all OIDC upstreams",
) )
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"ValidateToken() was called on the wrong OIDC upstream", "ValidateTokenAndMergeWithUserInfo() was called on the wrong OIDC upstream",
) )
require.Equal(t, expectedArgs, actualArgs) require.Equal(t, expectedArgs, actualArgs)
} }
@ -571,7 +554,7 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount
} }
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to ValidateToken()", "expected exactly zero calls to ValidateTokenAndMergeWithUserInfo()",
) )
} }

View File

@ -114,7 +114,7 @@ func (p *ProviderConfig) PasswordCredentialsGrantAndValidateTokens(ctx context.C
// There is no nonce to validate for a resource owner password credentials grant because it skips using // There is no nonce to validate for a resource owner password credentials grant because it skips using
// the authorize endpoint and goes straight to the token endpoint. // the authorize endpoint and goes straight to the token endpoint.
const skipNonceValidation nonce.Nonce = "" const skipNonceValidation nonce.Nonce = ""
return p.ValidateToken(ctx, tok, skipNonceValidation, true) return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, skipNonceValidation, true)
} }
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) { func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) {
@ -128,7 +128,7 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context,
return nil, err return nil, err
} }
return p.ValidateToken(ctx, tok, expectedIDTokenNonce, true) return p.ValidateTokenAndMergeWithUserInfo(ctx, tok, expectedIDTokenNonce, true)
} }
func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
@ -243,15 +243,17 @@ func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (st
return "", "", errors.New("downstream subject did not contain original upstream subject") return "", "", errors.New("downstream subject did not contain original upstream subject")
} }
split := strings.SplitN(downstreamSubject, "?sub=", 2) split := strings.SplitN(downstreamSubject, "?sub=", 2)
iss := split[0]
sub := split[1]
if iss == "" || sub == "" {
return "", "", errors.New("downstream subject was malformed")
}
return split[0], split[1], nil return split[0], split[1], nil
} }
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response, // ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response,
// if the provider offers the userinfo endpoint. // if the provider offers the userinfo endpoint.
// TODO check: func (p *ProviderConfig) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) {
// - whether the userinfo response must exist (maybe just based on whether there is a refresh token???
// - -> for the next story: userinfo has to exist but only if there isn't a refresh token.
func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool) (*oidctypes.Token, error) {
var validatedClaims = make(map[string]interface{}) var validatedClaims = make(map[string]interface{})
idTok, hasIDTok := tok.Extra("id_token").(string) idTok, hasIDTok := tok.Extra("id_token").(string)
@ -281,7 +283,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err) return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
} }
maybeLogClaims("claims from ID token", p.Name, validatedClaims) maybeLogClaims("claims from ID token", p.Name, validatedClaims)
idTokenExpiry = validated.Expiry idTokenExpiry = validated.Expiry // keep track of the id token expiry if we have an id token. Otherwise, it'll just be the zero value.
} }
idTokenSubject, _ := validatedClaims[oidc.IDTokenSubjectClaim].(string) idTokenSubject, _ := validatedClaims[oidc.IDTokenSubjectClaim].(string)
@ -320,31 +322,27 @@ func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, t
} }
// The sub (subject) Claim MUST always be returned in the UserInfo Response. // The sub (subject) Claim MUST always be returned in the UserInfo Response.
// However there may not be an id token. If there is an ID token, we must
// check it against the userinfo's subject.
//
// NOTE: Due to the possibility of token substitution attacks (see Section 16.11), the UserInfo Response is not // NOTE: Due to the possibility of token substitution attacks (see Section 16.11), the UserInfo Response is not
// guaranteed to be about the End-User identified by the sub (subject) element of the ID Token. The sub Claim in // guaranteed to be about the End-User identified by the sub (subject) element of the ID Token. The sub Claim in
// the UserInfo Response MUST be verified to exactly match the sub Claim in the ID Token; if they do not match, // the UserInfo Response MUST be verified to exactly match the sub Claim in the ID Token; if they do not match,
// the UserInfo Response values MUST NOT be used. // the UserInfo Response values MUST NOT be used.
// //
// http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse // http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
// If there is no ID token and it is not required, we must assume that the caller is performing other checks
// to ensure the subject is correct.
checkIDToken := requireIDToken || len(idTokenSubject) > 0 checkIDToken := requireIDToken || len(idTokenSubject) > 0
if checkIDToken && (len(userInfo.Subject) == 0 || userInfo.Subject != idTokenSubject) { if checkIDToken && (len(userInfo.Subject) == 0 || userInfo.Subject != idTokenSubject) {
return httperr.Newf(http.StatusUnprocessableEntity, "userinfo 'sub' claim (%s) did not match id_token 'sub' claim (%s)", userInfo.Subject, idTokenSubject) return httperr.Newf(http.StatusUnprocessableEntity, "userinfo 'sub' claim (%s) did not match id_token 'sub' claim (%s)", userInfo.Subject, idTokenSubject)
} }
if !checkIDToken { // keep track of the issuer from the ID token
claims["sub"] = userInfo.Subject // do this so other validations can check this subject later
}
idTokenIssuer := claims["iss"] idTokenIssuer := claims["iss"]
// merge existing claims with user info claims // merge existing claims with user info claims
if err := userInfo.Claims(&claims); err != nil { if err := userInfo.Claims(&claims); err != nil {
return httperr.Wrap(http.StatusInternalServerError, "could not unmarshal user info claims", err) return httperr.Wrap(http.StatusInternalServerError, "could not unmarshal user info claims", err)
} }
// The OIDC spec for user info response does not make any guarantees about the iss claim's existence or validity: // The OIDC spec for the UserInfo response does not make any guarantees about the iss claim's existence or validity:
// "If signed, the UserInfo Response SHOULD contain the Claims iss (issuer) and aud (audience) as members. The iss value SHOULD be the OP's Issuer Identifier URL." // "If signed, the UserInfo Response SHOULD contain the Claims iss (issuer) and aud (audience) as members. The iss value SHOULD be the OP's Issuer Identifier URL."
// See https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse // See https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
// So we just ignore it and use it the version from the id token, which has stronger guarantees. // So we just ignore it and use it the version from the id token, which has stronger guarantees.

View File

@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package upstreamoidc package upstreamoidc
@ -589,7 +589,7 @@ func TestProviderConfig(t *testing.T) {
} }
}) })
t.Run("ValidateToken", func(t *testing.T) { t.Run("ValidateTokenAndMergeWithUserInfo", func(t *testing.T) {
expiryTime := time.Now().Add(42 * time.Second) expiryTime := time.Now().Add(42 * time.Second)
testTokenWithoutIDToken := &oauth2.Token{ testTokenWithoutIDToken := &oauth2.Token{
AccessToken: "test-access-token", AccessToken: "test-access-token",
@ -646,7 +646,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: false, requireIDToken: false,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -673,7 +673,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -700,7 +700,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -727,7 +727,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "iss": "some-other-issuer"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "iss": "some-other-issuer", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -754,7 +754,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "", nonce: "",
requireIDToken: false, requireIDToken: false,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "iss": "some-other-issuer"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "iss": "some-other-issuer", "sub": "some-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -799,7 +799,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal", "sub": "some-other-subject"}`),
wantErr: "could not fetch user info claims: userinfo 'sub' claim (some-other-subject) did not match id_token 'sub' claim (some-subject)", wantErr: "could not fetch user info claims: userinfo 'sub' claim (some-other-subject) did not match id_token 'sub' claim (some-subject)",
}, },
{ {
@ -808,7 +808,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: false, requireIDToken: false,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal", "sub": "some-other-subject"}`),
wantErr: "could not fetch user info claims: userinfo 'sub' claim (some-other-subject) did not match id_token 'sub' claim (some-subject)", wantErr: "could not fetch user info claims: userinfo 'sub' claim (some-other-subject) did not match id_token 'sub' claim (some-subject)",
}, },
{ {
@ -817,7 +817,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal", "sub": "some-other-subject"}`),
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
}, },
{ {
@ -826,7 +826,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-other-nonce", nonce: "some-other-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal", "sub": "some-other-subject"}`),
wantErr: "received ID token with invalid nonce: invalid nonce (expected \"some-other-nonce\", got \"some-nonce\")", wantErr: "received ID token with invalid nonce: invalid nonce (expected \"some-other-nonce\", got \"some-nonce\")",
}, },
{ {
@ -835,7 +835,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-other-nonce", nonce: "some-other-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantErr: "received response missing ID token", wantErr: "received response missing ID token",
}, },
{ {
@ -844,7 +844,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-other-nonce", nonce: "some-other-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-subject", `{"name": "Pinny TheSeal", "sub": "some-subject"}`),
wantErr: "received response missing ID token", wantErr: "received response missing ID token",
}, },
{ {
@ -853,7 +853,7 @@ func TestProviderConfig(t *testing.T) {
nonce: "some-nonce", nonce: "some-nonce",
requireIDToken: true, requireIDToken: true,
rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`), rawClaims: []byte(`{"userinfo_endpoint": "not-empty"}`),
userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal"}`), userInfo: forceUserInfoWithClaims("some-other-subject", `{"name": "Pinny TheSeal", "sub": "some-other-subject"}`),
wantMergedTokens: &oidctypes.Token{ wantMergedTokens: &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -898,7 +898,7 @@ func TestProviderConfig(t *testing.T) {
userInfoErr: tt.userInfoErr, userInfoErr: tt.userInfoErr,
}, },
} }
gotTok, err := p.ValidateToken(context.Background(), tt.tok, tt.nonce, tt.requireIDToken) gotTok, err := p.ValidateTokenAndMergeWithUserInfo(context.Background(), tt.tok, tt.nonce, tt.requireIDToken)
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(t, err) require.Error(t, err)
require.Equal(t, tt.wantErr, err.Error()) require.Equal(t, tt.wantErr, err.Error())
@ -924,24 +924,27 @@ func TestProviderConfig(t *testing.T) {
wantUpstreamSubject: "some-subject", wantUpstreamSubject: "some-subject",
wantUpstreamIssuer: "https://some-issuer", wantUpstreamIssuer: "https://some-issuer",
}, },
{
name: "happy path but sub is empty string", // todo i think this should not be the responsibility of this function, even though it's undesirable behavior...
downstreamSubject: "https://some-issuer?sub=",
wantUpstreamSubject: "",
wantUpstreamIssuer: "https://some-issuer",
},
{
name: "happy path but iss is empty string",
downstreamSubject: "?sub=some-subject",
wantUpstreamSubject: "some-subject",
wantUpstreamIssuer: "",
},
{ {
name: "subject in a subject", name: "subject in a subject",
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject", downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
wantUpstreamSubject: "https://some-issuer?sub=some-subject", wantUpstreamSubject: "https://some-issuer?sub=some-subject",
wantUpstreamIssuer: "https://some-other-issuer", wantUpstreamIssuer: "https://some-other-issuer",
}, },
{
name: "sub is empty string",
downstreamSubject: "https://some-issuer?sub=",
wantErr: "downstream subject was malformed",
},
{
name: "iss is empty string",
downstreamSubject: "?sub=some-subject",
wantErr: "downstream subject was malformed",
},
{
name: "empty string",
downstreamSubject: "",
wantErr: "downstream subject did not contain original upstream subject",
},
{ {
name: "doesn't contain sub=", name: "doesn't contain sub=",
downstreamSubject: "something-invalid", downstreamSubject: "something-invalid",

View File

@ -822,7 +822,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
return upstreamOIDCIdentityProvider.ValidateToken(ctx, refreshed, "", true) return upstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfo(ctx, refreshed, "", true)
} }
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {