diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 2cea6c2e..3bdcd9ee 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -30,9 +30,10 @@ const ( // The name of the subject claim specified in the OIDC spec. idTokenSubjectClaim = "sub" - // defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC - // ID token if the upstream OIDC IDP did not tell us to use another claim. - defaultUpstreamUsernameClaim = idTokenSubjectClaim + // idTokenUsernameClaim is a custom claim in the downstream ID token + // whose value is mapped from a claim in the upstream token. + // By default the value is the same as the downstream subject claim's. + idTokenUsernameClaim = "username" // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // information. @@ -88,7 +89,7 @@ func NewHandler( return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } - username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) + subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) if err != nil { return err } @@ -98,7 +99,7 @@ func NewHandler( return err } - openIDSession := makeDownstreamSession(username, groups) + openIDSession := makeDownstreamSession(subject, username, groups) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) if err != nil { plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) @@ -192,36 +193,54 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP return &state, nil } -func getUsernameFromUpstreamIDToken( +func getSubjectAndUsernameFromUpstreamIDToken( upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, idTokenClaims map[string]interface{}, -) (string, error) { - usernameClaim := upstreamIDPConfig.GetUsernameClaim() +) (string, string, error) { + // The spec says the "sub" claim is only unique per issuer, + // so we will prepend the issuer string to make it globally unique. + upstreamIssuer := idTokenClaims[idTokenIssuerClaim] + if upstreamIssuer == "" { + plog.Warning( + "issuer claim in upstream ID token missing", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing") + } + upstreamIssuerAsString, ok := upstreamIssuer.(string) + if !ok { + plog.Warning( + "issuer claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") + } - user := "" + subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim] + if !ok { + plog.Warning( + "no subject claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + ) + return "", "", httperr.New(http.StatusUnprocessableEntity, "no subject claim in upstream ID token") + } + + upstreamSubject, ok := subjectAsInterface.(string) + if !ok { + plog.Warning( + "subject claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + ) + return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format") + } + + subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject) + + usernameClaim := upstreamIDPConfig.GetUsernameClaim() if usernameClaim == "" { - // The spec says the "sub" claim is only unique per issuer, so by default when there is - // no specific username claim configured we will prepend the issuer string to make it globally unique. - upstreamIssuer := idTokenClaims[idTokenIssuerClaim] - if upstreamIssuer == "" { - plog.Warning( - "issuer claim in upstream ID token missing", - "upstreamName", upstreamIDPConfig.GetName(), - "issClaim", upstreamIssuer, - ) - return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing") - } - upstreamIssuerAsString, ok := upstreamIssuer.(string) - if !ok { - plog.Warning( - "issuer claim in upstream ID token has invalid format", - "upstreamName", upstreamIDPConfig.GetName(), - "issClaim", upstreamIssuer, - ) - return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") - } - user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim) - usernameClaim = defaultUpstreamUsernameClaim + return subject, subject, nil } usernameAsInterface, ok := idTokenClaims[usernameClaim] @@ -232,7 +251,7 @@ func getUsernameFromUpstreamIDToken( "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), "usernameClaim", usernameClaim, ) - return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") + return "", "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") } username, ok := usernameAsInterface.(string) @@ -243,10 +262,10 @@ func getUsernameFromUpstreamIDToken( "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), "usernameClaim", usernameClaim, ) - return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") + return "", "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") } - return fmt.Sprintf("%s%s", user, username), nil + return subject, username, nil } func getGroupsFromUpstreamIDToken( @@ -283,19 +302,20 @@ func getGroupsFromUpstreamIDToken( return groups, nil } -func makeDownstreamSession(username string, groups []string) *openid.DefaultSession { +func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession { now := time.Now().UTC() openIDSession := &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ - Subject: username, + Subject: subject, RequestedAt: now, AuthTime: now, }, } + openIDSession.Claims.Extra = map[string]interface{}{ + idTokenUsernameClaim: username, + } if groups != nil { - openIDSession.Claims.Extra = map[string]interface{}{ - downstreamGroupsClaim: groups, - } + openIDSession.Claims.Extra[downstreamGroupsClaim] = groups } return openIDSession } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 38d12b29..eb7e4eea 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -133,6 +133,7 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp string wantDownstreamGrantedScopes []string wantDownstreamIDTokenSubject string + wantDownstreamIDTokenUsername string wantDownstreamIDTokenGroups []string wantDownstreamRequestedScopes []string wantDownstreamNonce string @@ -150,7 +151,8 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenUsername: upstreamUsername, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, @@ -169,6 +171,7 @@ func TestCallbackEndpoint(t *testing.T) { wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenGroups: nil, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, @@ -186,7 +189,8 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamSubject, + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenUsername: upstreamSubject, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, @@ -312,7 +316,8 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, - wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamRequestedScopes: []string{"profile", "email"}, wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamNonce: downstreamNonce, @@ -333,7 +338,8 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState, - wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamRequestedScopes: []string{"openid", "offline_access"}, wantDownstreamGrantedScopes: []string{"openid", "offline_access"}, wantDownstreamIDTokenGroups: upstreamGroupMembership, @@ -522,6 +528,7 @@ func TestCallbackEndpoint(t *testing.T) { authcodeDataAndSignature[1], // Authcode store key is authcode signature test.wantDownstreamGrantedScopes, test.wantDownstreamIDTokenSubject, + test.wantDownstreamIDTokenUsername, test.wantDownstreamIDTokenGroups, test.wantDownstreamRequestedScopes, ) @@ -742,6 +749,7 @@ func validateAuthcodeStorage( storeKey string, wantDownstreamGrantedScopes []string, wantDownstreamIDTokenSubject string, + wantDownstreamIDTokenUsername string, wantDownstreamIDTokenGroups []string, wantDownstreamRequestedScopes []string, ) (*fosite.Request, *openid.DefaultSession) { @@ -778,13 +786,14 @@ func validateAuthcodeStorage( // Now confirm the ID token claims. actualClaims := storedSessionFromAuthcode.Claims - // Check the user's identity, which are put into the downstream ID token's subject and groups claims. + // Check the user's identity, which are put into the downstream ID token's subject, username and groups claims. require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) + require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"]) if wantDownstreamIDTokenGroups != nil { - require.Len(t, actualClaims.Extra, 1) + require.Len(t, actualClaims.Extra, 2) require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) } else { - require.Empty(t, actualClaims.Extra) + require.Len(t, actualClaims.Extra, 1) require.NotContains(t, actualClaims.Extra, "groups") }