PR feedback-- omit empty groups, keep groups as nil until last minute

Also log keys and values for claims
This commit is contained in:
Margo Crawford 2021-01-14 15:11:00 -08:00
parent 6fce1bd6bb
commit d11a73c519
2 changed files with 12 additions and 17 deletions

View File

@ -258,7 +258,7 @@ func getGroupsFromUpstreamIDToken(
) ([]string, error) { ) ([]string, error) {
groupsClaim := upstreamIDPConfig.GetGroupsClaim() groupsClaim := upstreamIDPConfig.GetGroupsClaim()
if groupsClaim == "" { if groupsClaim == "" {
return []string{}, nil return nil, nil
} }
groupsAsInterface, ok := idTokenClaims[groupsClaim] groupsAsInterface, ok := idTokenClaims[groupsClaim]
@ -269,7 +269,7 @@ func getGroupsFromUpstreamIDToken(
"configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(),
"groupsClaim", groupsClaim, "groupsClaim", groupsClaim,
) )
return []string{}, nil // the upstream IDP may have omitted the claim if the user has no groups return nil, nil // the upstream IDP may have omitted the claim if the user has no groups
} }
groupsAsArray, okAsArray := extractGroups(groupsAsInterface) groupsAsArray, okAsArray := extractGroups(groupsAsInterface)
@ -302,13 +302,15 @@ func extractGroups(groupsAsInterface interface{}) ([]string, bool) {
return nil, false return nil, false
} }
groupsAsStrings := make([]string, len(groupsAsInterfaceArray)) var groupsAsStrings []string
for i, groupAsInterface := range groupsAsInterfaceArray { for _, groupAsInterface := range groupsAsInterfaceArray {
groupAsString, okAsString := groupAsInterface.(string) groupAsString, okAsString := groupAsInterface.(string)
if !okAsString { if !okAsString {
return nil, false return nil, false
} }
groupsAsStrings[i] = groupAsString if groupAsString != "" {
groupsAsStrings = append(groupsAsStrings, groupAsString)
}
} }
return groupsAsStrings, true return groupsAsStrings, true
@ -323,6 +325,9 @@ func makeDownstreamSession(subject string, username string, groups []string) *op
AuthTime: now, AuthTime: now,
}, },
} }
if groups == nil {
groups = []string{}
}
openIDSession.Claims.Extra = map[string]interface{}{ openIDSession.Claims.Extra = map[string]interface{}{
oidc.DownstreamUsernameClaim: username, oidc.DownstreamUsernameClaim: username,
oidc.DownstreamGroupsClaim: groups, oidc.DownstreamGroupsClaim: groups,

View File

@ -102,12 +102,12 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
if err := validated.Claims(&validatedClaims); err != nil { if err := validated.Claims(&validatedClaims); err != nil {
return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err) return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal id token claims", err)
} }
plog.All("claims from ID token", "providerName", p.Name, "claims", listClaims(validatedClaims)) plog.All("claims from ID token", "providerName", p.Name, "claims", validatedClaims)
if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil { if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil {
return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err) return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err)
} }
plog.All("claims from ID token and userinfo", "providerName", p.Name, "claims", listClaims(validatedClaims)) plog.All("claims from ID token and userinfo", "providerName", p.Name, "claims", validatedClaims)
return &oidctypes.Token{ return &oidctypes.Token{
AccessToken: &oidctypes.AccessToken{ AccessToken: &oidctypes.AccessToken{
@ -162,13 +162,3 @@ func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, c
return nil return nil
} }
func listClaims(claims map[string]interface{}) []string {
list := make([]string, len(claims))
i := 0
for claim := range claims {
list[i] = claim
i++
}
return list
}