WIP: Adjust subject and username claims
Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
a5c07042c1
commit
afcd5e3e36
@ -30,9 +30,10 @@ const (
|
|||||||
// The name of the subject claim specified in the OIDC spec.
|
// The name of the subject claim specified in the OIDC spec.
|
||||||
idTokenSubjectClaim = "sub"
|
idTokenSubjectClaim = "sub"
|
||||||
|
|
||||||
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
|
// idTokenUsernameClaim is a custom claim in the downstream ID token
|
||||||
// ID token if the upstream OIDC IDP did not tell us to use another claim.
|
// whose value is mapped from a claim in the upstream token.
|
||||||
defaultUpstreamUsernameClaim = idTokenSubjectClaim
|
// By default the value is the same as the downstream subject claim's.
|
||||||
|
idTokenUsernameClaim = "username"
|
||||||
|
|
||||||
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
||||||
// information.
|
// information.
|
||||||
@ -88,7 +89,7 @@ func NewHandler(
|
|||||||
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
||||||
}
|
}
|
||||||
|
|
||||||
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
|
subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -98,7 +99,7 @@ func NewHandler(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
openIDSession := makeDownstreamSession(username, groups)
|
openIDSession := makeDownstreamSession(subject, username, groups)
|
||||||
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
|
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
|
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
|
||||||
@ -192,16 +193,12 @@ func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateP
|
|||||||
return &state, nil
|
return &state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUsernameFromUpstreamIDToken(
|
func getSubjectAndUsernameFromUpstreamIDToken(
|
||||||
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
|
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
|
||||||
idTokenClaims map[string]interface{},
|
idTokenClaims map[string]interface{},
|
||||||
) (string, error) {
|
) (string, string, error) {
|
||||||
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
// The spec says the "sub" claim is only unique per issuer,
|
||||||
|
// so we will prepend the issuer string to make it globally unique.
|
||||||
user := ""
|
|
||||||
if usernameClaim == "" {
|
|
||||||
// The spec says the "sub" claim is only unique per issuer, so by default when there is
|
|
||||||
// no specific username claim configured we will prepend the issuer string to make it globally unique.
|
|
||||||
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
||||||
if upstreamIssuer == "" {
|
if upstreamIssuer == "" {
|
||||||
plog.Warning(
|
plog.Warning(
|
||||||
@ -209,7 +206,7 @@ func getUsernameFromUpstreamIDToken(
|
|||||||
"upstreamName", upstreamIDPConfig.GetName(),
|
"upstreamName", upstreamIDPConfig.GetName(),
|
||||||
"issClaim", upstreamIssuer,
|
"issClaim", upstreamIssuer,
|
||||||
)
|
)
|
||||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
|
||||||
}
|
}
|
||||||
upstreamIssuerAsString, ok := upstreamIssuer.(string)
|
upstreamIssuerAsString, ok := upstreamIssuer.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -218,10 +215,32 @@ func getUsernameFromUpstreamIDToken(
|
|||||||
"upstreamName", upstreamIDPConfig.GetName(),
|
"upstreamName", upstreamIDPConfig.GetName(),
|
||||||
"issClaim", upstreamIssuer,
|
"issClaim", upstreamIssuer,
|
||||||
)
|
)
|
||||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
||||||
}
|
}
|
||||||
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
|
|
||||||
usernameClaim = defaultUpstreamUsernameClaim
|
subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim]
|
||||||
|
if !ok {
|
||||||
|
plog.Warning(
|
||||||
|
"no subject claim in upstream ID token",
|
||||||
|
"upstreamName", upstreamIDPConfig.GetName(),
|
||||||
|
)
|
||||||
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "no subject claim in upstream ID token")
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreamSubject, ok := subjectAsInterface.(string)
|
||||||
|
if !ok {
|
||||||
|
plog.Warning(
|
||||||
|
"subject claim in upstream ID token has invalid format",
|
||||||
|
"upstreamName", upstreamIDPConfig.GetName(),
|
||||||
|
)
|
||||||
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
|
||||||
|
}
|
||||||
|
|
||||||
|
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject)
|
||||||
|
|
||||||
|
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
||||||
|
if usernameClaim == "" {
|
||||||
|
return subject, subject, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
usernameAsInterface, ok := idTokenClaims[usernameClaim]
|
usernameAsInterface, ok := idTokenClaims[usernameClaim]
|
||||||
@ -232,7 +251,7 @@ func getUsernameFromUpstreamIDToken(
|
|||||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||||
"usernameClaim", usernameClaim,
|
"usernameClaim", usernameClaim,
|
||||||
)
|
)
|
||||||
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
|
||||||
}
|
}
|
||||||
|
|
||||||
username, ok := usernameAsInterface.(string)
|
username, ok := usernameAsInterface.(string)
|
||||||
@ -243,10 +262,10 @@ func getUsernameFromUpstreamIDToken(
|
|||||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||||
"usernameClaim", usernameClaim,
|
"usernameClaim", usernameClaim,
|
||||||
)
|
)
|
||||||
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s%s", user, username), nil
|
return subject, username, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGroupsFromUpstreamIDToken(
|
func getGroupsFromUpstreamIDToken(
|
||||||
@ -283,19 +302,20 @@ func getGroupsFromUpstreamIDToken(
|
|||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession {
|
func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession {
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
openIDSession := &openid.DefaultSession{
|
openIDSession := &openid.DefaultSession{
|
||||||
Claims: &jwt.IDTokenClaims{
|
Claims: &jwt.IDTokenClaims{
|
||||||
Subject: username,
|
Subject: subject,
|
||||||
RequestedAt: now,
|
RequestedAt: now,
|
||||||
AuthTime: now,
|
AuthTime: now,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if groups != nil {
|
|
||||||
openIDSession.Claims.Extra = map[string]interface{}{
|
openIDSession.Claims.Extra = map[string]interface{}{
|
||||||
downstreamGroupsClaim: groups,
|
idTokenUsernameClaim: username,
|
||||||
}
|
}
|
||||||
|
if groups != nil {
|
||||||
|
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups
|
||||||
}
|
}
|
||||||
return openIDSession
|
return openIDSession
|
||||||
}
|
}
|
||||||
|
@ -133,6 +133,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
wantRedirectLocationRegexp string
|
wantRedirectLocationRegexp string
|
||||||
wantDownstreamGrantedScopes []string
|
wantDownstreamGrantedScopes []string
|
||||||
wantDownstreamIDTokenSubject string
|
wantDownstreamIDTokenSubject string
|
||||||
|
wantDownstreamIDTokenUsername string
|
||||||
wantDownstreamIDTokenGroups []string
|
wantDownstreamIDTokenGroups []string
|
||||||
wantDownstreamRequestedScopes []string
|
wantDownstreamRequestedScopes []string
|
||||||
wantDownstreamNonce string
|
wantDownstreamNonce string
|
||||||
@ -150,7 +151,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
|
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||||
@ -169,6 +171,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
|
wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
wantDownstreamIDTokenGroups: nil,
|
wantDownstreamIDTokenGroups: nil,
|
||||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||||
@ -186,7 +189,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
wantDownstreamIDTokenSubject: upstreamSubject,
|
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
|
wantDownstreamIDTokenUsername: upstreamSubject,
|
||||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||||
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
|
||||||
@ -312,7 +316,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
||||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||||
|
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
wantDownstreamRequestedScopes: []string{"profile", "email"},
|
wantDownstreamRequestedScopes: []string{"profile", "email"},
|
||||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||||
wantDownstreamNonce: downstreamNonce,
|
wantDownstreamNonce: downstreamNonce,
|
||||||
@ -333,7 +338,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState,
|
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid%20offline_access&state=` + happyDownstreamState,
|
||||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
wantDownstreamIDTokenUsername: upstreamUsername,
|
||||||
|
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||||
wantDownstreamRequestedScopes: []string{"openid", "offline_access"},
|
wantDownstreamRequestedScopes: []string{"openid", "offline_access"},
|
||||||
wantDownstreamGrantedScopes: []string{"openid", "offline_access"},
|
wantDownstreamGrantedScopes: []string{"openid", "offline_access"},
|
||||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||||
@ -522,6 +528,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
authcodeDataAndSignature[1], // Authcode store key is authcode signature
|
authcodeDataAndSignature[1], // Authcode store key is authcode signature
|
||||||
test.wantDownstreamGrantedScopes,
|
test.wantDownstreamGrantedScopes,
|
||||||
test.wantDownstreamIDTokenSubject,
|
test.wantDownstreamIDTokenSubject,
|
||||||
|
test.wantDownstreamIDTokenUsername,
|
||||||
test.wantDownstreamIDTokenGroups,
|
test.wantDownstreamIDTokenGroups,
|
||||||
test.wantDownstreamRequestedScopes,
|
test.wantDownstreamRequestedScopes,
|
||||||
)
|
)
|
||||||
@ -742,6 +749,7 @@ func validateAuthcodeStorage(
|
|||||||
storeKey string,
|
storeKey string,
|
||||||
wantDownstreamGrantedScopes []string,
|
wantDownstreamGrantedScopes []string,
|
||||||
wantDownstreamIDTokenSubject string,
|
wantDownstreamIDTokenSubject string,
|
||||||
|
wantDownstreamIDTokenUsername string,
|
||||||
wantDownstreamIDTokenGroups []string,
|
wantDownstreamIDTokenGroups []string,
|
||||||
wantDownstreamRequestedScopes []string,
|
wantDownstreamRequestedScopes []string,
|
||||||
) (*fosite.Request, *openid.DefaultSession) {
|
) (*fosite.Request, *openid.DefaultSession) {
|
||||||
@ -778,13 +786,14 @@ func validateAuthcodeStorage(
|
|||||||
// Now confirm the ID token claims.
|
// Now confirm the ID token claims.
|
||||||
actualClaims := storedSessionFromAuthcode.Claims
|
actualClaims := storedSessionFromAuthcode.Claims
|
||||||
|
|
||||||
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
|
// Check the user's identity, which are put into the downstream ID token's subject, username and groups claims.
|
||||||
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
|
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
|
||||||
|
require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"])
|
||||||
if wantDownstreamIDTokenGroups != nil {
|
if wantDownstreamIDTokenGroups != nil {
|
||||||
require.Len(t, actualClaims.Extra, 1)
|
require.Len(t, actualClaims.Extra, 2)
|
||||||
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
|
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
|
||||||
} else {
|
} else {
|
||||||
require.Empty(t, actualClaims.Extra)
|
require.Len(t, actualClaims.Extra, 1)
|
||||||
require.NotContains(t, actualClaims.Extra, "groups")
|
require.NotContains(t, actualClaims.Extra, "groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user