WIP: Adjust subject and username claims

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Margo Crawford 2020-12-14 17:05:53 -08:00 committed by Ryan Richard
parent a5c07042c1
commit afcd5e3e36
2 changed files with 75 additions and 46 deletions

View File

@ -30,9 +30,10 @@ const (
// The name of the subject claim specified in the OIDC spec.
idTokenSubjectClaim = "sub"
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
// ID token if the upstream OIDC IDP did not tell us to use another claim.
defaultUpstreamUsernameClaim = idTokenSubjectClaim
// idTokenUsernameClaim is a custom claim in the downstream ID token
// whose value is mapped from a claim in the upstream token.
// By default the value is the same as the downstream subject claim's.
idTokenUsernameClaim = "username"
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
@ -88,7 +89,7 @@ func NewHandler(
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
}
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
if err != nil {
return err
}
@ -98,7 +99,7 @@ func NewHandler(
return err
}
openIDSession := makeDownstreamSession(username, groups)
openIDSession := makeDownstreamSession(subject, username, groups)
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
if err != nil {
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
@ -192,16 +193,12 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP
return &state, nil
}
func getUsernameFromUpstreamIDToken(
func getSubjectAndUsernameFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{},
) (string, error) {
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
user := ""
if usernameClaim == "" {
// The spec says the "sub" claim is only unique per issuer, so by default when there is
// no specific username claim configured we will prepend the issuer string to make it globally unique.
) (string, string, error) {
// The spec says the "sub" claim is only unique per issuer,
// so we will prepend the issuer string to make it globally unique.
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
if upstreamIssuer == "" {
plog.Warning(
@ -209,7 +206,7 @@ func getUsernameFromUpstreamIDToken(
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
}
upstreamIssuerAsString, ok := upstreamIssuer.(string)
if !ok {
@ -218,10 +215,32 @@ func getUsernameFromUpstreamIDToken(
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
}
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
usernameClaim = defaultUpstreamUsernameClaim
subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim]
if !ok {
plog.Warning(
"no subject claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
)
return "", "", httperr.New(http.StatusUnprocessableEntity, "no subject claim in upstream ID token")
}
upstreamSubject, ok := subjectAsInterface.(string)
if !ok {
plog.Warning(
"subject claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
)
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
}
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject)
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
if usernameClaim == "" {
return subject, subject, nil
}
usernameAsInterface, ok := idTokenClaims[usernameClaim]
@ -232,7 +251,7 @@ func getUsernameFromUpstreamIDToken(
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim,
)
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
return "", "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
}
username, ok := usernameAsInterface.(string)
@ -243,10 +262,10 @@ func getUsernameFromUpstreamIDToken(
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim,
)
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
return "", "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
}
return fmt.Sprintf("%s%s", user, username), nil
return subject, username, nil
}
func getGroupsFromUpstreamIDToken(
@ -283,19 +302,20 @@ func getGroupsFromUpstreamIDToken(
return groups, nil
}
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession {
func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession {
now := time.Now().UTC()
openIDSession := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: username,
Subject: subject,
RequestedAt: now,
AuthTime: now,
},
}
if groups != nil {
openIDSession.Claims.Extra = map[string]interface{}{
downstreamGroupsClaim: groups,
idTokenUsernameClaim: username,
}
if groups != nil {
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups
}
return openIDSession
}

View File

@ -133,6 +133,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp string
wantDownstreamGrantedScopes []string
wantDownstreamIDTokenSubject string
wantDownstreamIDTokenUsername string
wantDownstreamIDTokenGroups []string
wantDownstreamRequestedScopes []string
wantDownstreamNonce string
@ -150,7 +151,8 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -169,6 +171,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenGroups: nil,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -186,7 +189,8 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamSubject,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamSubject,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
@ -312,7 +316,8 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamRequestedScopes: []string{"profile", "email"},
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamNonce: downstreamNonce,
@ -333,7 +338,8 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamRequestedScopes: []string{"openid", "offline_access"},
wantDownstreamGrantedScopes: []string{"openid", "offline_access"},
wantDownstreamIDTokenGroups: upstreamGroupMembership,
@ -522,6 +528,7 @@ func TestCallbackEndpoint(t *testing.T) {
authcodeDataAndSignature[1], // Authcode store key is authcode signature
test.wantDownstreamGrantedScopes,
test.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenUsername,
test.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes,
)
@ -742,6 +749,7 @@ func validateAuthcodeStorage(
storeKey string,
wantDownstreamGrantedScopes []string,
wantDownstreamIDTokenSubject string,
wantDownstreamIDTokenUsername string,
wantDownstreamIDTokenGroups []string,
wantDownstreamRequestedScopes []string,
) (*fosite.Request, *openid.DefaultSession) {
@ -778,13 +786,14 @@ func validateAuthcodeStorage(
// Now confirm the ID token claims.
actualClaims := storedSessionFromAuthcode.Claims
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
// Check the user's identity, which are put into the downstream ID token's subject, username and groups claims.
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"])
if wantDownstreamIDTokenGroups != nil {
require.Len(t, actualClaims.Extra, 1)
require.Len(t, actualClaims.Extra, 2)
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
} else {
require.Empty(t, actualClaims.Extra)
require.Len(t, actualClaims.Extra, 1)
require.NotContains(t, actualClaims.Extra, "groups")
}