Allow upstream group claim values to be either arrays or strings

This commit is contained in:
Ryan Richard 2020-12-15 08:34:24 -08:00
parent 16dfab0aff
commit 43bb7117b7
2 changed files with 44 additions and 11 deletions

View File

@ -254,7 +254,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
func getGroupsFromUpstreamIDToken( func getGroupsFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) ([]string, error) { ) (interface{}, error) {
groupsClaim := upstreamIDPConfig.GetGroupsClaim() groupsClaim := upstreamIDPConfig.GetGroupsClaim()
if groupsClaim == "" { if groupsClaim == "" {
return nil, nil return nil, nil
@ -271,8 +271,9 @@ func getGroupsFromUpstreamIDToken(
return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token") return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token")
} }
groups, ok := groupsAsInterface.([]string) groupsAsArray, okAsArray := groupsAsInterface.([]string)
if !ok { groupsAsString, okAsString := groupsAsInterface.(string)
if !okAsArray && !okAsString {
plog.Warning( plog.Warning(
"groups claim in upstream ID token has invalid format", "groups claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(), "upstreamName", upstreamIDPConfig.GetName(),
@ -282,10 +283,13 @@ func getGroupsFromUpstreamIDToken(
return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format") return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format")
} }
return groups, nil if okAsArray {
return groupsAsArray, nil
}
return groupsAsString, nil
} }
func makeDownstreamSession(subject string, username string, groups []string) *openid.DefaultSession { func makeDownstreamSession(subject string, username string, groups interface{}) *openid.DefaultSession {
now := time.Now().UTC() now := time.Now().UTC()
openIDSession := &openid.DefaultSession{ openIDSession := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{

View File

@ -134,7 +134,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantDownstreamGrantedScopes []string wantDownstreamGrantedScopes []string
wantDownstreamIDTokenSubject string wantDownstreamIDTokenSubject string
wantDownstreamIDTokenUsername string wantDownstreamIDTokenUsername string
wantDownstreamIDTokenGroups []string wantDownstreamIDTokenGroups interface{}
wantDownstreamRequestedScopes []string wantDownstreamRequestedScopes []string
wantDownstreamNonce string wantDownstreamNonce string
wantDownstreamPKCEChallenge string wantDownstreamPKCEChallenge string
@ -199,6 +199,25 @@ func TestCallbackEndpoint(t *testing.T) {
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
}, },
{
name: "upstream IDP's configured groups claim in the ID token has a non-array value",
idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, "notAnArrayGroup1 notAnArrayGroup2").Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenGroups: "notAnArrayGroup1 notAnArrayGroup2",
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
// Pre-upstream-exchange verification // Pre-upstream-exchange verification
{ {
@ -682,8 +701,8 @@ func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
} }
} }
func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder { func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(value string) *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = claim u.usernameClaim = value
return u return u
} }
@ -750,7 +769,7 @@ func validateAuthcodeStorage(
wantDownstreamGrantedScopes []string, wantDownstreamGrantedScopes []string,
wantDownstreamIDTokenSubject string, wantDownstreamIDTokenSubject string,
wantDownstreamIDTokenUsername string, wantDownstreamIDTokenUsername string,
wantDownstreamIDTokenGroups []string, wantDownstreamIDTokenGroups interface{},
wantDownstreamRequestedScopes []string, wantDownstreamRequestedScopes []string,
) (*fosite.Request, *openid.DefaultSession) { ) (*fosite.Request, *openid.DefaultSession) {
t.Helper() t.Helper()
@ -789,9 +808,19 @@ func validateAuthcodeStorage(
// Check the user's identity, which are put into the downstream ID token's subject, username and groups claims. // Check the user's identity, which are put into the downstream ID token's subject, username and groups claims.
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"]) require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"])
if wantDownstreamIDTokenGroups != nil { if wantDownstreamIDTokenGroups != nil { //nolint:nestif // there are some nested if's here but its probably fine for a test
require.Len(t, actualClaims.Extra, 2) require.Len(t, actualClaims.Extra, 2)
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) wantArray, ok := wantDownstreamIDTokenGroups.([]string)
if ok {
require.ElementsMatch(t, wantArray, actualClaims.Extra["groups"])
} else {
wantString, ok := wantDownstreamIDTokenGroups.(string)
if ok {
require.Equal(t, wantString, actualClaims.Extra["groups"])
} else {
require.Fail(t, "wantDownstreamIDTokenGroups should be of type: either []string or string")
}
}
} else { } else {
require.Len(t, actualClaims.Extra, 1) require.Len(t, actualClaims.Extra, 1)
require.NotContains(t, actualClaims.Extra, "groups") require.NotContains(t, actualClaims.Extra, "groups")