fix token_handler_test.go

This commit is contained in:
Ryan Richard 2023-06-26 12:40:13 -07:00
parent 9d792352bf
commit b71e5964aa
2 changed files with 60 additions and 37 deletions

View File

@ -135,6 +135,8 @@ func upstreamOIDCRefresh(
clientID string, clientID string,
) error { ) error {
s := session.Custom s := session.Custom
groupsScopeGranted := slices.Contains(grantedScopes, oidcapi.ScopeGroups)
if s.OIDC == nil { if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
@ -186,15 +188,6 @@ func upstreamOIDCRefresh(
} }
mergedClaims := validatedTokens.IDToken.Claims mergedClaims := validatedTokens.IDToken.Claims
oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil {
return err
}
oldTransformedGroups, err := getDownstreamGroupsFromPinnipedSession(session)
if err != nil {
return err
}
// To the extent possible, check that the user's basic identity hasn't changed. // To the extent possible, check that the user's basic identity hasn't changed.
err = validateSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims, session) err = validateSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims, session)
if err != nil { if err != nil {
@ -202,8 +195,7 @@ func upstreamOIDCRefresh(
} }
var refreshedUntransformedGroups []string var refreshedUntransformedGroups []string
groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups) if groupsScopeGranted {
if groupsScope {
// If possible, update the user's group memberships. The configured groups claim name (if there is one) may or // If possible, update the user's group memberships. The configured groups claim name (if there is one) may or
// may not be included in the newly fetched and merged claims. It could be missing due to a misconfiguration of the // may not be included in the newly fetched and merged claims. It could be missing due to a misconfiguration of the
// claim name. It could also be missing because the claim was originally found in the ID token during login, but // claim name. It could also be missing because the claim was originally found in the ID token during login, but
@ -231,9 +223,25 @@ func upstreamOIDCRefresh(
if refreshedUntransformedGroups == nil { if refreshedUntransformedGroups == nil {
// If we could not get a new list of groups, then we still need the untransformed groups list to be able to // If we could not get a new list of groups, then we still need the untransformed groups list to be able to
// run the transformations again, so fetch the original untransformed groups list from the session. // run the transformations again, so fetch the original untransformed groups list from the session.
// We should also run the transformations on the original groups even when the groups scope was not granted,
// because a transformation policy may want to reject the authentication based on the group memberships, even
// though the group memberships will not be shared with the client (in the code below) due to the groups scope
// not being granted.
refreshedUntransformedGroups = s.UpstreamGroups refreshedUntransformedGroups = s.UpstreamGroups
} }
oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil {
return err
}
var oldTransformedGroups []string
if groupsScopeGranted {
oldTransformedGroups, err = getDownstreamGroupsFromPinnipedSession(session)
if err != nil {
return err
}
}
transformationResult, err := transformRefreshedIdentity(ctx, transformationResult, err := transformRefreshedIdentity(ctx,
p.Transforms, p.Transforms,
oldTransformedUsername, oldTransformedUsername,
@ -246,8 +254,11 @@ func upstreamOIDCRefresh(
return err return err
} }
if groupsScopeGranted {
warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID) warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = refreshedUntransformedGroups // Replace the old value with the new value.
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups
}
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
@ -334,19 +345,20 @@ func upstreamLDAPRefresh(
grantedScopes []string, grantedScopes []string,
clientID string, clientID string,
) error { ) error {
s := session.Custom
groupsScopeGranted := slices.Contains(grantedScopes, oidcapi.ScopeGroups)
oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session) oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil { if err != nil {
return err return err
} }
subject := session.Fosite.Claims.Subject
var oldTransformedGroups []string var oldTransformedGroups []string
if slices.Contains(grantedScopes, oidcapi.ScopeGroups) { if groupsScopeGranted {
oldTransformedGroups, err = getDownstreamGroupsFromPinnipedSession(session) oldTransformedGroups, err = getDownstreamGroupsFromPinnipedSession(session)
if err != nil { if err != nil {
return err return err
} }
} }
s := session.Custom
validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != "" validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != ""
validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != "" validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != ""
@ -369,9 +381,13 @@ func upstreamLDAPRefresh(
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
plog.Debug("attempting upstream refresh request",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
oldUntransformedUsername := s.UpstreamUsername
refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.RefreshAttributes{ refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.RefreshAttributes{
Username: s.UpstreamUsername, Username: oldUntransformedUsername,
Subject: subject, Subject: session.Fosite.Claims.Subject,
DN: dn, DN: dn,
Groups: s.UpstreamGroups, Groups: s.UpstreamGroups,
AdditionalAttributes: additionalAttributes, AdditionalAttributes: additionalAttributes,
@ -386,7 +402,7 @@ func upstreamLDAPRefresh(
transformationResult, err := transformRefreshedIdentity(ctx, transformationResult, err := transformRefreshedIdentity(ctx,
p.Transforms, p.Transforms,
oldTransformedUsername, oldTransformedUsername,
s.UpstreamUsername, oldUntransformedUsername, // LDAP PerformRefresh validates that the username did not change, so this is also the refreshed upstream username
refreshedUntransformedGroups, refreshedUntransformedGroups,
s.ProviderName, s.ProviderName,
s.ProviderType, s.ProviderType,
@ -395,8 +411,7 @@ func upstreamLDAPRefresh(
return err return err
} }
groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups) if groupsScopeGranted {
if groupsScope {
warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID) warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
// Replace the old value with the new value. // Replace the old value with the new value.
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups

View File

@ -1764,6 +1764,8 @@ func TestRefreshGrant(t *testing.T) {
initialUpstreamOIDCRefreshTokenCustomSessionData := func() *psession.CustomSessionData { initialUpstreamOIDCRefreshTokenCustomSessionData := func() *psession.CustomSessionData {
return &psession.CustomSessionData{ return &psession.CustomSessionData{
Username: goodUsername, Username: goodUsername,
UpstreamUsername: goodUsername,
UpstreamGroups: goodGroups,
ProviderName: oidcUpstreamName, ProviderName: oidcUpstreamName,
ProviderUID: oidcUpstreamResourceUID, ProviderUID: oidcUpstreamResourceUID,
ProviderType: oidcUpstreamType, ProviderType: oidcUpstreamType,
@ -1778,6 +1780,8 @@ func TestRefreshGrant(t *testing.T) {
initialUpstreamOIDCAccessTokenCustomSessionData := func() *psession.CustomSessionData { initialUpstreamOIDCAccessTokenCustomSessionData := func() *psession.CustomSessionData {
return &psession.CustomSessionData{ return &psession.CustomSessionData{
Username: goodUsername, Username: goodUsername,
UpstreamUsername: goodUsername,
UpstreamGroups: goodGroups,
ProviderName: oidcUpstreamName, ProviderName: oidcUpstreamName,
ProviderUID: oidcUpstreamResourceUID, ProviderUID: oidcUpstreamResourceUID,
ProviderType: oidcUpstreamType, ProviderType: oidcUpstreamType,
@ -1918,6 +1922,8 @@ func TestRefreshGrant(t *testing.T) {
happyActiveDirectoryCustomSessionData := &psession.CustomSessionData{ happyActiveDirectoryCustomSessionData := &psession.CustomSessionData{
Username: goodUsername, Username: goodUsername,
UpstreamUsername: goodUsername,
UpstreamGroups: goodGroups,
ProviderUID: activeDirectoryUpstreamResourceUID, ProviderUID: activeDirectoryUpstreamResourceUID,
ProviderName: activeDirectoryUpstreamName, ProviderName: activeDirectoryUpstreamName,
ProviderType: activeDirectoryUpstreamType, ProviderType: activeDirectoryUpstreamType,
@ -1928,6 +1934,8 @@ func TestRefreshGrant(t *testing.T) {
happyLDAPCustomSessionData := &psession.CustomSessionData{ happyLDAPCustomSessionData := &psession.CustomSessionData{
Username: goodUsername, Username: goodUsername,
UpstreamUsername: goodUsername,
UpstreamGroups: goodGroups,
ProviderUID: ldapUpstreamResourceUID, ProviderUID: ldapUpstreamResourceUID,
ProviderName: ldapUpstreamName, ProviderName: ldapUpstreamName,
ProviderType: ldapUpstreamType, ProviderType: ldapUpstreamType,