WIP: Adjust subject and username claims
Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
a5c07042c1
commit
afcd5e3e36
@ -30,9 +30,10 @@ const (
|
||||
// The name of the subject claim specified in the OIDC spec.
|
||||
idTokenSubjectClaim = "sub"
|
||||
|
||||
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
|
||||
// ID token if the upstream OIDC IDP did not tell us to use another claim.
|
||||
defaultUpstreamUsernameClaim = idTokenSubjectClaim
|
||||
// idTokenUsernameClaim is a custom claim in the downstream ID token
|
||||
// whose value is mapped from a claim in the upstream token.
|
||||
// By default the value is the same as the downstream subject claim's.
|
||||
idTokenUsernameClaim = "username"
|
||||
|
||||
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
||||
// information.
|
||||
@ -88,7 +89,7 @@ func NewHandler(
|
||||
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
||||
}
|
||||
|
||||
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
|
||||
subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -98,7 +99,7 @@ func NewHandler(
|
||||
return err
|
||||
}
|
||||
|
||||
openIDSession := makeDownstreamSession(username, groups)
|
||||
openIDSession := makeDownstreamSession(subject, username, groups)
|
||||
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
|
||||
if err != nil {
|
||||
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
|
||||
@ -192,36 +193,54 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func getUsernameFromUpstreamIDToken(
|
||||
func getSubjectAndUsernameFromUpstreamIDToken(
|
||||
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
|
||||
idTokenClaims map[string]interface{},
|
||||
) (string, error) {
|
||||
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
||||
) (string, string, error) {
|
||||
// The spec says the "sub" claim is only unique per issuer,
|
||||
// so we will prepend the issuer string to make it globally unique.
|
||||
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
||||
if upstreamIssuer == "" {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token missing",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
|
||||
}
|
||||
upstreamIssuerAsString, ok := upstreamIssuer.(string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
||||
}
|
||||
|
||||
user := ""
|
||||
subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim]
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"no subject claim in upstream ID token",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
)
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "no subject claim in upstream ID token")
|
||||
}
|
||||
|
||||
upstreamSubject, ok := subjectAsInterface.(string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"subject claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
)
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
|
||||
}
|
||||
|
||||
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject)
|
||||
|
||||
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
||||
if usernameClaim == "" {
|
||||
// The spec says the "sub" claim is only unique per issuer, so by default when there is
|
||||
// no specific username claim configured we will prepend the issuer string to make it globally unique.
|
||||
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
||||
if upstreamIssuer == "" {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token missing",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
|
||||
}
|
||||
upstreamIssuerAsString, ok := upstreamIssuer.(string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
||||
}
|
||||
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
|
||||
usernameClaim = defaultUpstreamUsernameClaim
|
||||
return subject, subject, nil
|
||||
}
|
||||
|
||||
usernameAsInterface, ok := idTokenClaims[usernameClaim]
|
||||
@ -232,7 +251,7 @@ func getUsernameFromUpstreamIDToken(
|
||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||
"usernameClaim", usernameClaim,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
|
||||
}
|
||||
|
||||
username, ok := usernameAsInterface.(string)
|
||||
@ -243,10 +262,10 @@ func getUsernameFromUpstreamIDToken(
|
||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||
"usernameClaim", usernameClaim,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
|
||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", user, username), nil
|
||||
return subject, username, nil
|
||||
}
|
||||
|
||||
func getGroupsFromUpstreamIDToken(
|
||||
@ -283,19 +302,20 @@ func getGroupsFromUpstreamIDToken(
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession {
|
||||
func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession {
|
||||
now := time.Now().UTC()
|
||||
openIDSession := &openid.DefaultSession{
|
||||
Claims: &jwt.IDTokenClaims{
|
||||
Subject: username,
|
||||
Subject: subject,
|
||||
RequestedAt: now,
|
||||
AuthTime: now,
|
||||
},
|
||||
}
|
||||
openIDSession.Claims.Extra = map[string]interface{}{
|
||||
idTokenUsernameClaim: username,
|
||||
}
|
||||
if groups != nil {
|
||||
openIDSession.Claims.Extra = map[string]interface{}{
|
||||
downstreamGroupsClaim: groups,
|
||||
}
|
||||
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups
|
||||
}
|
||||
return openIDSession
|
||||
}
|
||||
|
@ -133,6 +133,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
wantRedirectLocationRegexp string
|
||||
wantDownstreamGrantedScopes []string
|
||||
wantDownstreamIDTokenSubject string
|
||||
wantDownstreamIDTokenUsername string
|
||||
wantDownstreamIDTokenGroups []string
|
||||
wantDownstreamRequestedScopes []string
|
||||
wantDownstreamNonce string
|
||||
@ -150,7 +151,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||
@ -169,6 +171,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamIDTokenGroups: nil,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||
@ -186,7 +189,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamSubject,
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamIDTokenUsername: upstreamSubject,
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||
@ -312,7 +316,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
||||
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamRequestedScopes: []string{"profile", "email"},
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamNonce: downstreamNonce,
|
||||
@ -333,7 +338,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState,
|
||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
||||
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamRequestedScopes: []string{"openid", "offline_access"},
|
||||
wantDownstreamGrantedScopes: []string{"openid", "offline_access"},
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
@ -522,6 +528,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
authcodeDataAndSignature[1], // Authcode store key is authcode signature
|
||||
test.wantDownstreamGrantedScopes,
|
||||
test.wantDownstreamIDTokenSubject,
|
||||
test.wantDownstreamIDTokenUsername,
|
||||
test.wantDownstreamIDTokenGroups,
|
||||
test.wantDownstreamRequestedScopes,
|
||||
)
|
||||
@ -742,6 +749,7 @@ func validateAuthcodeStorage(
|
||||
storeKey string,
|
||||
wantDownstreamGrantedScopes []string,
|
||||
wantDownstreamIDTokenSubject string,
|
||||
wantDownstreamIDTokenUsername string,
|
||||
wantDownstreamIDTokenGroups []string,
|
||||
wantDownstreamRequestedScopes []string,
|
||||
) (*fosite.Request, *openid.DefaultSession) {
|
||||
@ -778,13 +786,14 @@ func validateAuthcodeStorage(
|
||||
// Now confirm the ID token claims.
|
||||
actualClaims := storedSessionFromAuthcode.Claims
|
||||
|
||||
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
|
||||
// Check the user's identity, which are put into the downstream ID token's subject, username and groups claims.
|
||||
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
|
||||
require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"])
|
||||
if wantDownstreamIDTokenGroups != nil {
|
||||
require.Len(t, actualClaims.Extra, 1)
|
||||
require.Len(t, actualClaims.Extra, 2)
|
||||
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
|
||||
} else {
|
||||
require.Empty(t, actualClaims.Extra)
|
||||
require.Len(t, actualClaims.Extra, 1)
|
||||
require.NotContains(t, actualClaims.Extra, "groups")
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user