Upstream ldap group refresh:
- Doing it inline on the refresh request
This commit is contained in:
parent
46dd73de70
commit
013b521838
@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
|
||||
authenticators.UserAuthenticator
|
||||
|
||||
// PerformRefresh performs a refresh against the upstream LDAP identity provider
|
||||
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) error
|
||||
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error)
|
||||
}
|
||||
|
||||
type StoredRefreshAttributes struct {
|
||||
|
@ -301,7 +301,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
// run PerformRefresh
|
||||
err = p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||
Username: username,
|
||||
Subject: subject,
|
||||
DN: dn,
|
||||
@ -312,6 +312,10 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
"Upstream refresh failed.").WithWrap(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
// If we got groups back, then replace the old value with the new value.
|
||||
if groups != nil {
|
||||
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1339,6 +1339,60 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path refresh grant when the upstream refresh returns new group memberships from LDAP, it updates groups",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
customSessionData: happyLDAPCustomSessionData,
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||
happyLDAPCustomSessionData,
|
||||
),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantStatus: http.StatusOK,
|
||||
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||
wantGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path refresh grant when the upstream refresh returns empty list of group memberships from LDAP, it updates groups to an empty list",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: []string{},
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
customSessionData: happyLDAPCustomSessionData,
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||
happyLDAPCustomSessionData,
|
||||
),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantStatus: http.StatusOK,
|
||||
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||
wantGroups: []string{},
|
||||
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
||||
@ -1970,6 +2024,7 @@ func TestRefreshGrant(t *testing.T) {
|
||||
Name: ldapUpstreamName,
|
||||
ResourceUID: ldapUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: goodGroups,
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
@ -1990,6 +2045,7 @@ func TestRefreshGrant(t *testing.T) {
|
||||
Name: activeDirectoryUpstreamName,
|
||||
ResourceUID: activeDirectoryUpstreamResourceUID,
|
||||
URL: ldapUpstreamURL,
|
||||
PerformRefreshGroups: goodGroups,
|
||||
}),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
|
@ -101,6 +101,7 @@ type TestUpstreamLDAPIdentityProvider struct {
|
||||
performRefreshCallCount int
|
||||
performRefreshArgs []*PerformRefreshArgs
|
||||
PerformRefreshErr error
|
||||
PerformRefreshGroups []string
|
||||
}
|
||||
|
||||
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
|
||||
@ -121,7 +122,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
||||
return u.URL
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
@ -133,9 +134,9 @@ func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, s
|
||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||
})
|
||||
if u.PerformRefreshErr != nil {
|
||||
return u.PerformRefreshErr
|
||||
return nil, u.PerformRefreshErr
|
||||
}
|
||||
return nil
|
||||
return u.PerformRefreshGroups, nil
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int {
|
||||
|
@ -170,61 +170,11 @@ func (p *Provider) GetConfig() ProviderConfig {
|
||||
return p.c
|
||||
}
|
||||
|
||||
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
|
||||
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
|
||||
userDN := storedRefreshAttributes.DN
|
||||
|
||||
searchResult, err := p.performRefresh(ctx, userDN)
|
||||
if err != nil {
|
||||
p.traceRefreshFailure(t, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// if any more or less than one entry, error.
|
||||
// we don't need to worry about logging this because we know it's a dn.
|
||||
if len(searchResult.Entries) != 1 {
|
||||
return fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
||||
userDN, len(searchResult.Entries),
|
||||
)
|
||||
}
|
||||
|
||||
userEntry := searchResult.Entries[0]
|
||||
if len(userEntry.DN) == 0 {
|
||||
return fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
||||
}
|
||||
|
||||
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if newUsername != storedRefreshAttributes.Username {
|
||||
return fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
||||
userDN, storedRefreshAttributes.Username, newUsername,
|
||||
)
|
||||
}
|
||||
|
||||
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
||||
if newSubject != storedRefreshAttributes.Subject {
|
||||
return fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
||||
}
|
||||
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
||||
err = validateFunc(userEntry, storedRefreshAttributes)
|
||||
if err != nil {
|
||||
return fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
||||
}
|
||||
}
|
||||
// we checked that the user still exists and their information is the same, so just return.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.SearchResult, error) {
|
||||
search := p.refreshUserSearchRequest(userDN)
|
||||
|
||||
conn, err := p.dial(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
|
||||
@ -236,6 +186,65 @@ func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.Sea
|
||||
return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err)
|
||||
}
|
||||
|
||||
searchResult, err := p.performUserRefresh(conn, userDN)
|
||||
if err != nil {
|
||||
p.traceRefreshFailure(t, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if any more or less than one entry, error.
|
||||
// we don't need to worry about logging this because we know it's a dn.
|
||||
if len(searchResult.Entries) != 1 {
|
||||
return nil, fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
||||
userDN, len(searchResult.Entries),
|
||||
)
|
||||
}
|
||||
|
||||
userEntry := searchResult.Entries[0]
|
||||
if len(userEntry.DN) == 0 {
|
||||
return nil, fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
||||
}
|
||||
|
||||
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if newUsername != storedRefreshAttributes.Username {
|
||||
return nil, fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
||||
userDN, storedRefreshAttributes.Username, newUsername,
|
||||
)
|
||||
}
|
||||
|
||||
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
||||
if newSubject != storedRefreshAttributes.Subject {
|
||||
return nil, fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
||||
}
|
||||
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
||||
err = validateFunc(userEntry, storedRefreshAttributes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have group search configured, search for groups to update the value.
|
||||
if len(p.c.GroupSearch.Base) > 0 {
|
||||
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(mappedGroupNames)
|
||||
return mappedGroupNames, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *Provider) performUserRefresh(conn Conn, userDN string) (*ldap.SearchResult, error) {
|
||||
search := p.refreshUserSearchRequest(userDN)
|
||||
|
||||
searchResult, err := conn.Search(search)
|
||||
|
||||
if err != nil {
|
||||
@ -455,7 +464,7 @@ func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, er
|
||||
groupAttributeName = distinguishedNameAttributeName
|
||||
}
|
||||
|
||||
var groups []string
|
||||
groups := []string{}
|
||||
entries:
|
||||
for _, groupEntry := range searchResult.Entries {
|
||||
if len(groupEntry.DN) == 0 {
|
||||
|
@ -1100,6 +1100,18 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
Controls: nil, // don't need paging because we set the SizeLimit so small
|
||||
}
|
||||
|
||||
expectedGroupSearch := &ldap.SearchRequest{
|
||||
BaseDN: testGroupSearchBase,
|
||||
Scope: ldap.ScopeWholeSubtree,
|
||||
DerefAliases: ldap.NeverDerefAliases,
|
||||
SizeLimit: 0, // unlimited size because we will search with paging
|
||||
TimeLimit: 90,
|
||||
TypesOnly: false,
|
||||
Filter: testGroupSearchFilterInterpolated,
|
||||
Attributes: []string{testGroupSearchGroupNameAttribute},
|
||||
Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us
|
||||
}
|
||||
|
||||
happyPathUserSearchResult := &ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{
|
||||
{
|
||||
@ -1146,6 +1158,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
setupMocks func(conn *mockldapconn.MockConn)
|
||||
dialError error
|
||||
wantErr string
|
||||
wantGroups []string
|
||||
}{
|
||||
{
|
||||
name: "happy path where searching the dn returns a single entry",
|
||||
@ -1156,6 +1169,89 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path where group search returns groups",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{
|
||||
{
|
||||
DN: testGroupSearchResultDNValue1,
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}),
|
||||
},
|
||||
},
|
||||
{
|
||||
DN: testGroupSearchResultDNValue2,
|
||||
Attributes: []*ldap.EntryAttribute{
|
||||
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}),
|
||||
},
|
||||
},
|
||||
},
|
||||
Referrals: []string{}, // note that we are not following referrals at this time
|
||||
Controls: []ldap.Control{},
|
||||
}, nil).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2},
|
||||
},
|
||||
{
|
||||
name: "happy path where group search returns no groups",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||
Entries: []*ldap.Entry{},
|
||||
Referrals: []string{}, // note that we are not following referrals at this time
|
||||
Controls: []ldap.Control{},
|
||||
}, nil).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantGroups: []string{},
|
||||
},
|
||||
{
|
||||
name: "error where dial fails",
|
||||
providerConfig: providerConfig,
|
||||
@ -1421,6 +1517,37 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
},
|
||||
wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login",
|
||||
},
|
||||
{
|
||||
name: "group search returns an error",
|
||||
providerConfig: &ProviderConfig{
|
||||
Name: "some-provider-name",
|
||||
Host: testHost,
|
||||
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||
ConnectionProtocol: TLS,
|
||||
BindUsername: testBindUsername,
|
||||
BindPassword: testBindPassword,
|
||||
UserSearch: UserSearchConfig{
|
||||
Base: testUserSearchBase,
|
||||
UIDAttribute: testUserSearchUIDAttribute,
|
||||
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||
},
|
||||
GroupSearch: GroupSearchConfig{
|
||||
Base: testGroupSearchBase,
|
||||
Filter: testGroupSearchFilter,
|
||||
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||
},
|
||||
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||
},
|
||||
},
|
||||
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1)
|
||||
conn.EXPECT().Close().Times(1)
|
||||
},
|
||||
wantErr: "error searching for group memberships for user with DN \"some-upstream-user-dn\": some search error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -1435,9 +1562,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
dialWasAttempted := false
|
||||
providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
dialWasAttempted = true
|
||||
require.Equal(t, providerConfig.Host, addr.Endpoint())
|
||||
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
|
||||
if tt.dialError != nil {
|
||||
return nil, tt.dialError
|
||||
}
|
||||
@ -1446,9 +1573,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
|
||||
ldapProvider := New(*providerConfig)
|
||||
ldapProvider := New(*tt.providerConfig)
|
||||
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
|
||||
err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
||||
groups, err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
||||
Username: testUserSearchResultUsernameAttributeValue,
|
||||
Subject: subject,
|
||||
DN: testUserSearchResultDNValue,
|
||||
@ -1461,6 +1588,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, true, dialWasAttempted)
|
||||
require.Equal(t, tt.wantGroups, groups)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -257,7 +257,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
|
||||
p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree
|
||||
})),
|
||||
wantAuthResponse: &authenticators.Response{
|
||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil},
|
||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
|
||||
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
||||
ExtraRefreshAttributes: map[string]string{},
|
||||
},
|
||||
@ -328,7 +328,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
|
||||
p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema
|
||||
})),
|
||||
wantAuthResponse: &authenticators.Response{
|
||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil},
|
||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
|
||||
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
||||
ExtraRefreshAttributes: map[string]string{},
|
||||
},
|
||||
|
@ -66,6 +66,9 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
||||
// Either revoke the user's session on the upstream provider, or manipulate the user's session
|
||||
// data in such a way that it should cause the next upstream refresh attempt to fail.
|
||||
breakRefreshSessionData func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string)
|
||||
// Edit the refresh session data between the initial login and the refresh, which is expected to
|
||||
// succeed.
|
||||
editRefreshSessionDataWithoutBreaking func(t *testing.T, sessionData *psession.PinnipedSession)
|
||||
}{
|
||||
{
|
||||
name: "oidc with default username and groups claim settings",
|
||||
@ -128,6 +131,14 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
||||
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Issuer+"?sub=") + ".+",
|
||||
wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" },
|
||||
wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups,
|
||||
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession) {
|
||||
// even if we update this group to the wrong thing, we expect that it will return to the correct
|
||||
// value after we refresh.
|
||||
// However if there are no expected groups then they will not update, so we should skip this.
|
||||
if len(env.SupervisorUpstreamOIDC.ExpectedGroups) > 0 {
|
||||
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oidc without refresh token",
|
||||
@ -272,6 +283,11 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
||||
false,
|
||||
)
|
||||
},
|
||||
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession) {
|
||||
// even if we update this group to the wrong thing, we expect that it will return to the correct
|
||||
// value after we refresh.
|
||||
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
|
||||
},
|
||||
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
|
||||
customSessionData := pinnipedSession.Custom
|
||||
require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType)
|
||||
@ -1272,6 +1288,7 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
||||
testSupervisorLogin(t,
|
||||
tt.createIDP,
|
||||
tt.requestAuthorization,
|
||||
tt.editRefreshSessionDataWithoutBreaking,
|
||||
tt.breakRefreshSessionData,
|
||||
tt.createTestUser,
|
||||
tt.deleteTestUser,
|
||||
@ -1405,8 +1422,9 @@ func requireEventuallySuccessfulActiveDirectoryIdentityProviderConditions(t *tes
|
||||
func testSupervisorLogin(
|
||||
t *testing.T,
|
||||
createIDP func(t *testing.T) string,
|
||||
requestAuthorization func(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, username, password string, httpClient *http.Client),
|
||||
breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, username string),
|
||||
requestAuthorization func(t *testing.T, downstreamAuthorizeURL string, downstreamCallbackURL string, username string, password string, httpClient *http.Client),
|
||||
editRefreshSessionDataWithoutBreaking func(t *testing.T, pinnipedSession *psession.PinnipedSession),
|
||||
breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName string, username string),
|
||||
createTestUser func(t *testing.T) (string, string),
|
||||
deleteTestUser func(t *testing.T, username string),
|
||||
wantDownstreamIDTokenSubjectToMatch string,
|
||||
@ -1565,6 +1583,28 @@ func testSupervisorLogin(
|
||||
// token exchange on the original token
|
||||
doTokenExchange(t, &downstreamOAuth2Config, tokenResponse, httpClient, discovery)
|
||||
|
||||
if editRefreshSessionDataWithoutBreaking != nil {
|
||||
latestRefreshToken := tokenResponse.RefreshToken
|
||||
signatureOfLatestRefreshToken := getFositeDataSignature(t, latestRefreshToken)
|
||||
|
||||
// First use the latest downstream refresh token to look up the corresponding session in the Supervisor's storage.
|
||||
kubeClient := testlib.NewKubernetesClientset(t)
|
||||
supervisorSecretsClient := kubeClient.CoreV1().Secrets(env.SupervisorNamespace)
|
||||
oauthStore := oidc.NewKubeStorage(supervisorSecretsClient, oidc.DefaultOIDCTimeoutsConfiguration())
|
||||
storedRefreshSession, err := oauthStore.GetRefreshTokenSession(ctx, signatureOfLatestRefreshToken, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Next mutate the part of the session that is used during upstream refresh.
|
||||
pinnipedSession, ok := storedRefreshSession.GetSession().(*psession.PinnipedSession)
|
||||
require.True(t, ok, "should have been able to cast session data to PinnipedSession")
|
||||
|
||||
editRefreshSessionDataWithoutBreaking(t, pinnipedSession)
|
||||
|
||||
// Then save the mutated Secret back to Kubernetes.
|
||||
// There is no update function, so delete and create again at the same name.
|
||||
require.NoError(t, oauthStore.DeleteRefreshTokenSession(ctx, signatureOfLatestRefreshToken))
|
||||
require.NoError(t, oauthStore.CreateRefreshTokenSession(ctx, signatureOfLatestRefreshToken, storedRefreshSession))
|
||||
}
|
||||
// Use the refresh token to get new tokens
|
||||
refreshSource := downstreamOAuth2Config.TokenSource(oidcHTTPClientContext, &oauth2.Token{RefreshToken: tokenResponse.RefreshToken})
|
||||
refreshedTokenResponse, err := refreshSource.Token()
|
||||
|
Loading…
Reference in New Issue
Block a user