WIP: Adjust subject and username claims

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Margo Crawford 2020-12-14 17:05:53 -08:00 committed by Ryan Richard
parent a5c07042c1
commit afcd5e3e36
2 changed files with 75 additions and 46 deletions

View File

@ -30,9 +30,10 @@ const (
// The name of the subject claim specified in the OIDC spec. // The name of the subject claim specified in the OIDC spec.
idTokenSubjectClaim = "sub" idTokenSubjectClaim = "sub"
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC // idTokenUsernameClaim is a custom claim in the downstream ID token
// ID token if the upstream OIDC IDP did not tell us to use another claim. // whose value is mapped from a claim in the upstream token.
defaultUpstreamUsernameClaim = idTokenSubjectClaim // By default the value is the same as the downstream subject claim's.
idTokenUsernameClaim = "username"
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information. // information.
@ -88,7 +89,7 @@ func NewHandler(
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
} }
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
if err != nil { if err != nil {
return err return err
} }
@ -98,7 +99,7 @@ func NewHandler(
return err return err
} }
openIDSession := makeDownstreamSession(username, groups) openIDSession := makeDownstreamSession(subject, username, groups)
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
if err != nil { if err != nil {
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
@ -192,16 +193,12 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP
return &state, nil return &state, nil
} }
func getUsernameFromUpstreamIDToken( func getSubjectAndUsernameFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) (string, error) { ) (string, string, error) {
usernameClaim := upstreamIDPConfig.GetUsernameClaim() // The spec says the "sub" claim is only unique per issuer,
// so we will prepend the issuer string to make it globally unique.
user := ""
if usernameClaim == "" {
// The spec says the "sub" claim is only unique per issuer, so by default when there is
// no specific username claim configured we will prepend the issuer string to make it globally unique.
upstreamIssuer := idTokenClaims[idTokenIssuerClaim] upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
if upstreamIssuer == "" { if upstreamIssuer == "" {
plog.Warning( plog.Warning(
@ -209,7 +206,7 @@ func getUsernameFromUpstreamIDToken(
"upstreamName", upstreamIDPConfig.GetName(), "upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer, "issClaim", upstreamIssuer,
) )
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing") return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
} }
upstreamIssuerAsString, ok := upstreamIssuer.(string) upstreamIssuerAsString, ok := upstreamIssuer.(string)
if !ok { if !ok {
@ -218,10 +215,32 @@ func getUsernameFromUpstreamIDToken(
"upstreamName", upstreamIDPConfig.GetName(), "upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer, "issClaim", upstreamIssuer,
) )
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
} }
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
usernameClaim = defaultUpstreamUsernameClaim subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim]
if !ok {
plog.Warning(
"no subject claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
)
return "", "", httperr.New(http.StatusUnprocessableEntity, "no subject claim in upstream ID token")
}
upstreamSubject, ok := subjectAsInterface.(string)
if !ok {
plog.Warning(
"subject claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
)
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
}
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject)
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
if usernameClaim == "" {
return subject, subject, nil
} }
usernameAsInterface, ok := idTokenClaims[usernameClaim] usernameAsInterface, ok := idTokenClaims[usernameClaim]
@ -232,7 +251,7 @@ func getUsernameFromUpstreamIDToken(
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim, "usernameClaim", usernameClaim,
) )
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") return "", "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
} }
username, ok := usernameAsInterface.(string) username, ok := usernameAsInterface.(string)
@ -243,10 +262,10 @@ func getUsernameFromUpstreamIDToken(
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim, "usernameClaim", usernameClaim,
) )
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") return "", "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
} }
return fmt.Sprintf("%s%s", user, username), nil return subject, username, nil
} }
func getGroupsFromUpstreamIDToken( func getGroupsFromUpstreamIDToken(
@ -283,19 +302,20 @@ func getGroupsFromUpstreamIDToken(
return groups, nil return groups, nil
} }
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession { func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession {
now := time.Now().UTC() now := time.Now().UTC()
openIDSession := &openid.DefaultSession{ openIDSession := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Subject: username, Subject: subject,
RequestedAt: now, RequestedAt: now,
AuthTime: now, AuthTime: now,
}, },
} }
if groups != nil {
openIDSession.Claims.Extra = map[string]interface{}{ openIDSession.Claims.Extra = map[string]interface{}{
downstreamGroupsClaim: groups, idTokenUsernameClaim: username,
} }
if groups != nil {
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups
} }
return openIDSession return openIDSession
} }

View File

@ -133,6 +133,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp string wantRedirectLocationRegexp string
wantDownstreamGrantedScopes []string wantDownstreamGrantedScopes []string
wantDownstreamIDTokenSubject string wantDownstreamIDTokenSubject string
wantDownstreamIDTokenUsername string
wantDownstreamIDTokenGroups []string wantDownstreamIDTokenGroups []string
wantDownstreamRequestedScopes []string wantDownstreamRequestedScopes []string
wantDownstreamNonce string wantDownstreamNonce string
@ -150,7 +151,8 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "", wantBody: "",
wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -169,6 +171,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "", wantBody: "",
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenGroups: nil, wantDownstreamIDTokenGroups: nil,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -186,7 +189,8 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "", wantBody: "",
wantDownstreamIDTokenSubject: upstreamSubject, wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamSubject,
wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -312,7 +316,8 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamRequestedScopes: []string{"profile", "email"}, wantDownstreamRequestedScopes: []string{"profile", "email"},
wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamNonce: downstreamNonce, wantDownstreamNonce: downstreamNonce,
@ -333,7 +338,8 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername, wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamRequestedScopes: []string{"openid", "offline_access"}, wantDownstreamRequestedScopes: []string{"openid", "offline_access"},
wantDownstreamGrantedScopes: []string{"openid", "offline_access"}, wantDownstreamGrantedScopes: []string{"openid", "offline_access"},
wantDownstreamIDTokenGroups: upstreamGroupMembership, wantDownstreamIDTokenGroups: upstreamGroupMembership,
@ -522,6 +528,7 @@ func TestCallbackEndpoint(t *testing.T) {
authcodeDataAndSignature[1], // Authcode store key is authcode signature authcodeDataAndSignature[1], // Authcode store key is authcode signature
test.wantDownstreamGrantedScopes, test.wantDownstreamGrantedScopes,
test.wantDownstreamIDTokenSubject, test.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenUsername,
test.wantDownstreamIDTokenGroups, test.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes, test.wantDownstreamRequestedScopes,
) )
@ -742,6 +749,7 @@ func validateAuthcodeStorage(
storeKey string, storeKey string,
wantDownstreamGrantedScopes []string, wantDownstreamGrantedScopes []string,
wantDownstreamIDTokenSubject string, wantDownstreamIDTokenSubject string,
wantDownstreamIDTokenUsername string,
wantDownstreamIDTokenGroups []string, wantDownstreamIDTokenGroups []string,
wantDownstreamRequestedScopes []string, wantDownstreamRequestedScopes []string,
) (*fosite.Request, *openid.DefaultSession) { ) (*fosite.Request, *openid.DefaultSession) {
@ -778,13 +786,14 @@ func validateAuthcodeStorage(
// Now confirm the ID token claims. // Now confirm the ID token claims.
actualClaims := storedSessionFromAuthcode.Claims actualClaims := storedSessionFromAuthcode.Claims
// Check the user's identity, which are put into the downstream ID token's subject and groups claims. // Check the user's identity, which are put into the downstream ID token's subject, username and groups claims.
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"])
if wantDownstreamIDTokenGroups != nil { if wantDownstreamIDTokenGroups != nil {
require.Len(t, actualClaims.Extra, 1) require.Len(t, actualClaims.Extra, 2)
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
} else { } else {
require.Empty(t, actualClaims.Extra) require.Len(t, actualClaims.Extra, 1)
require.NotContains(t, actualClaims.Extra, "groups") require.NotContains(t, actualClaims.Extra, "groups")
} }