Fix some tests in supervisor_login_test.go

This commit is contained in:
Ryan Richard 2023-06-26 15:26:24 -07:00
parent 98ee9f0979
commit 0f23931fe4
2 changed files with 37 additions and 24 deletions

View File

@ -116,10 +116,8 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType { switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC: case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, idpLister, grantedScopes, clientID) return upstreamOIDCRefresh(ctx, idpLister, session, grantedScopes, clientID)
case psession.ProviderTypeLDAP: case psession.ProviderTypeLDAP, psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, idpLister, session, grantedScopes, clientID)
case psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, idpLister, session, grantedScopes, clientID) return upstreamLDAPRefresh(ctx, idpLister, session, grantedScopes, clientID)
default: default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
@ -129,8 +127,8 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
//nolint:funlen //nolint:funlen
func upstreamOIDCRefresh( func upstreamOIDCRefresh(
ctx context.Context, ctx context.Context,
session *psession.PinnipedSession,
idpLister federationdomainproviders.FederationDomainIdentityProvidersListerI, idpLister federationdomainproviders.FederationDomainIdentityProvidersListerI,
session *psession.PinnipedSession,
grantedScopes []string, grantedScopes []string,
clientID string, clientID string,
) error { ) error {
@ -215,10 +213,12 @@ func upstreamOIDCRefresh(
// but if it is, verify that the transformed version of it hasn't changed. // but if it is, verify that the transformed version of it hasn't changed.
refreshedUntransformedUsername, hasRefreshedUntransformedUsername := getString(mergedClaims, p.Provider.GetUsernameClaim()) refreshedUntransformedUsername, hasRefreshedUntransformedUsername := getString(mergedClaims, p.Provider.GetUsernameClaim())
oldUntransformedUsername := s.UpstreamUsername
oldUntransformedGroups := s.UpstreamGroups
if !hasRefreshedUntransformedUsername { if !hasRefreshedUntransformedUsername {
// If we could not get a new username, then we still need the untransformed username to be able to // If we could not get a new username, then we still need the untransformed username to be able to
// run the transformations again, so fetch the original untransformed username from the session. // run the transformations again, so fetch the original untransformed username from the session.
refreshedUntransformedUsername = s.UpstreamUsername refreshedUntransformedUsername = oldUntransformedUsername
} }
if refreshedUntransformedGroups == nil { if refreshedUntransformedGroups == nil {
// If we could not get a new list of groups, then we still need the untransformed groups list to be able to // If we could not get a new list of groups, then we still need the untransformed groups list to be able to
@ -227,7 +227,7 @@ func upstreamOIDCRefresh(
// because a transformation policy may want to reject the authentication based on the group memberships, even // because a transformation policy may want to reject the authentication based on the group memberships, even
// though the group memberships will not be shared with the client (in the code below) due to the groups scope // though the group memberships will not be shared with the client (in the code below) due to the groups scope
// not being granted. // not being granted.
refreshedUntransformedGroups = s.UpstreamGroups refreshedUntransformedGroups = oldUntransformedGroups
} }
oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session) oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
@ -256,7 +256,7 @@ func upstreamOIDCRefresh(
if groupsScopeGranted { if groupsScopeGranted {
warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID) warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
// Replace the old value with the new value. // Replace the old value for the downstream groups in the user's session with the new value.
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups
} }
@ -385,11 +385,12 @@ func upstreamLDAPRefresh(
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID) "providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
oldUntransformedUsername := s.UpstreamUsername oldUntransformedUsername := s.UpstreamUsername
oldUntransformedGroups := s.UpstreamGroups
refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.RefreshAttributes{ refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.RefreshAttributes{
Username: oldUntransformedUsername, Username: oldUntransformedUsername,
Subject: session.Fosite.Claims.Subject, Subject: session.Fosite.Claims.Subject,
DN: dn, DN: dn,
Groups: s.UpstreamGroups, Groups: oldUntransformedGroups,
AdditionalAttributes: additionalAttributes, AdditionalAttributes: additionalAttributes,
GrantedScopes: grantedScopes, GrantedScopes: grantedScopes,
}) })
@ -413,7 +414,7 @@ func upstreamLDAPRefresh(
if groupsScopeGranted { if groupsScopeGranted {
warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID) warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
// Replace the old value with the new value. // Replace the old value for the downstream groups in the user's session with the new value.
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups
} }

View File

@ -297,11 +297,13 @@ func TestSupervisorLogin_Browser(t *testing.T) {
wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" }, wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" },
wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups, wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups,
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string { editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string {
// even if we update this group to the wrong thing, we expect that it will return to the correct // Even if we update this group to the some names that did not come from the OIDC server,
// value after we refresh. // we expect that it will return to the real groups from the OIDC server after we refresh.
// However if there are no expected groups then they will not update, so we should skip this. // However if there are no expected groups then they will not update, so we should skip this.
if len(env.SupervisorUpstreamOIDC.ExpectedGroups) > 0 { if len(env.SupervisorUpstreamOIDC.ExpectedGroups) > 0 {
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"} initialGroupMembership := []string{"some-wrong-group", "some-other-group"}
sessionData.Custom.UpstreamGroups = initialGroupMembership // upstream group names in session
sessionData.Fosite.Claims.Extra["groups"] = initialGroupMembership // downstream group names in session
} }
return env.SupervisorUpstreamOIDC.ExpectedGroups return env.SupervisorUpstreamOIDC.ExpectedGroups
}, },
@ -450,9 +452,11 @@ func TestSupervisorLogin_Browser(t *testing.T) {
) )
}, },
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string { editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string {
// even if we update this group to the wrong thing, we expect that it will return to the correct // Even if we update this group to the some names that did not come from the LDAP server,
// value after we refresh. // we expect that it will return to the real groups from the LDAP server after we refresh.
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"} initialGroupMembership := []string{"some-wrong-group", "some-other-group"}
sessionData.Custom.UpstreamGroups = initialGroupMembership // upstream group names in session
sessionData.Fosite.Claims.Extra["groups"] = initialGroupMembership // downstream group names in session
return env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs return env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs
}, },
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) { breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
@ -656,11 +660,16 @@ func TestSupervisorLogin_Browser(t *testing.T) {
) )
}, },
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string { editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string {
// update the list of groups to the wrong thing and see that they do not get updated because // Update the list of groups to some groups that would not come from the real LDAP queries,
// skip group refresh is set // and see that these become the user's new groups after refresh, because LDAP skip group refresh is set.
wrongGroups := []string{"some-wrong-group", "some-other-group"} // Since we are skipping the LDAP group query during refresh, the refresh should use what the session
sessionData.Fosite.Claims.Extra["groups"] = wrongGroups // says is the original list of untransformed groups from the initial login, and then perform the
return wrongGroups // transformations on them again. However, since there are no transformations configured, they will not
// be changed by any transformations in this case.
initialGroupMembership := []string{"some-wrong-group", "some-other-group"} // these groups are not in LDAP server
sessionData.Custom.UpstreamGroups = initialGroupMembership // upstream group names in session
sessionData.Fosite.Claims.Extra["groups"] = initialGroupMembership // downstream group names in session
return initialGroupMembership // these are the expected groups after the refresh is performed
}, },
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) { breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
customSessionData := pinnipedSession.Custom customSessionData := pinnipedSession.Custom
@ -709,9 +718,12 @@ func TestSupervisorLogin_Browser(t *testing.T) {
) )
}, },
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string { editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) []string {
// even if we update this group to the wrong thing, we expect that it will return to the correct // Even if we update this group to the some names that did not come from the LDAP server,
// value (no groups) after we refresh. // we expect that it will return to the real groups from the LDAP server after we refresh,
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"} // which in this case is no groups since this test uses a group search base which results in no groups.
initialGroupMembership := []string{"some-wrong-group", "some-other-group"}
sessionData.Custom.UpstreamGroups = initialGroupMembership // upstream group names in session
sessionData.Fosite.Claims.Extra["groups"] = initialGroupMembership // downstream group names in session
return []string{} return []string{}
}, },
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) { breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {