Upstream ldap group refresh:
- Doing it inline on the refresh request
This commit is contained in:
parent
46dd73de70
commit
013b521838
@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
|
|||||||
authenticators.UserAuthenticator
|
authenticators.UserAuthenticator
|
||||||
|
|
||||||
// PerformRefresh performs a refresh against the upstream LDAP identity provider
|
// PerformRefresh performs a refresh against the upstream LDAP identity provider
|
||||||
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) error
|
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type StoredRefreshAttributes struct {
|
type StoredRefreshAttributes struct {
|
||||||
|
@ -301,7 +301,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
|||||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||||
}
|
}
|
||||||
// run PerformRefresh
|
// run PerformRefresh
|
||||||
err = p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||||
Username: username,
|
Username: username,
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
DN: dn,
|
DN: dn,
|
||||||
@ -312,6 +312,10 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
|||||||
"Upstream refresh failed.").WithWrap(err).
|
"Upstream refresh failed.").WithWrap(err).
|
||||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
}
|
}
|
||||||
|
// If we got groups back, then replace the old value with the new value.
|
||||||
|
if groups != nil {
|
||||||
|
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1339,6 +1339,60 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "happy path refresh grant when the upstream refresh returns new group memberships from LDAP, it updates groups",
|
||||||
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||||
|
Name: ldapUpstreamName,
|
||||||
|
ResourceUID: ldapUpstreamResourceUID,
|
||||||
|
URL: ldapUpstreamURL,
|
||||||
|
PerformRefreshGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||||
|
}),
|
||||||
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
customSessionData: happyLDAPCustomSessionData,
|
||||||
|
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||||
|
happyLDAPCustomSessionData,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
refreshRequest: refreshRequestInputs{
|
||||||
|
want: tokenEndpointResponseExpectedValues{
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||||
|
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||||
|
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||||
|
wantGroups: []string{"new-group1", "new-group2", "new-group3"},
|
||||||
|
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||||
|
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "happy path refresh grant when the upstream refresh returns empty list of group memberships from LDAP, it updates groups to an empty list",
|
||||||
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||||
|
Name: ldapUpstreamName,
|
||||||
|
ResourceUID: ldapUpstreamResourceUID,
|
||||||
|
URL: ldapUpstreamURL,
|
||||||
|
PerformRefreshGroups: []string{},
|
||||||
|
}),
|
||||||
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
customSessionData: happyLDAPCustomSessionData,
|
||||||
|
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
|
||||||
|
happyLDAPCustomSessionData,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
refreshRequest: refreshRequestInputs{
|
||||||
|
want: tokenEndpointResponseExpectedValues{
|
||||||
|
wantStatus: http.StatusOK,
|
||||||
|
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
|
||||||
|
wantRequestedScopes: []string{"openid", "offline_access"},
|
||||||
|
wantGrantedScopes: []string{"openid", "offline_access"},
|
||||||
|
wantGroups: []string{},
|
||||||
|
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
|
||||||
|
wantCustomSessionDataStored: happyLDAPCustomSessionData,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value",
|
name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value",
|
||||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
||||||
@ -1970,6 +2024,7 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
Name: ldapUpstreamName,
|
Name: ldapUpstreamName,
|
||||||
ResourceUID: ldapUpstreamResourceUID,
|
ResourceUID: ldapUpstreamResourceUID,
|
||||||
URL: ldapUpstreamURL,
|
URL: ldapUpstreamURL,
|
||||||
|
PerformRefreshGroups: goodGroups,
|
||||||
}),
|
}),
|
||||||
authcodeExchange: authcodeExchangeInputs{
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
@ -1990,6 +2045,7 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
Name: activeDirectoryUpstreamName,
|
Name: activeDirectoryUpstreamName,
|
||||||
ResourceUID: activeDirectoryUpstreamResourceUID,
|
ResourceUID: activeDirectoryUpstreamResourceUID,
|
||||||
URL: ldapUpstreamURL,
|
URL: ldapUpstreamURL,
|
||||||
|
PerformRefreshGroups: goodGroups,
|
||||||
}),
|
}),
|
||||||
authcodeExchange: authcodeExchangeInputs{
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
@ -101,6 +101,7 @@ type TestUpstreamLDAPIdentityProvider struct {
|
|||||||
performRefreshCallCount int
|
performRefreshCallCount int
|
||||||
performRefreshArgs []*PerformRefreshArgs
|
performRefreshArgs []*PerformRefreshArgs
|
||||||
PerformRefreshErr error
|
PerformRefreshErr error
|
||||||
|
PerformRefreshGroups []string
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
|
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
|
||||||
@ -121,7 +122,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
|||||||
return u.URL
|
return u.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||||
if u.performRefreshArgs == nil {
|
if u.performRefreshArgs == nil {
|
||||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||||
}
|
}
|
||||||
@ -133,9 +134,9 @@ func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, s
|
|||||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||||
})
|
})
|
||||||
if u.PerformRefreshErr != nil {
|
if u.PerformRefreshErr != nil {
|
||||||
return u.PerformRefreshErr
|
return nil, u.PerformRefreshErr
|
||||||
}
|
}
|
||||||
return nil
|
return u.PerformRefreshGroups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int {
|
func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int {
|
||||||
|
@ -170,61 +170,11 @@ func (p *Provider) GetConfig() ProviderConfig {
|
|||||||
return p.c
|
return p.c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
|
||||||
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
|
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
|
||||||
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
|
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
|
||||||
userDN := storedRefreshAttributes.DN
|
userDN := storedRefreshAttributes.DN
|
||||||
|
|
||||||
searchResult, err := p.performRefresh(ctx, userDN)
|
|
||||||
if err != nil {
|
|
||||||
p.traceRefreshFailure(t, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// if any more or less than one entry, error.
|
|
||||||
// we don't need to worry about logging this because we know it's a dn.
|
|
||||||
if len(searchResult.Entries) != 1 {
|
|
||||||
return fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
|
||||||
userDN, len(searchResult.Entries),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
userEntry := searchResult.Entries[0]
|
|
||||||
if len(userEntry.DN) == 0 {
|
|
||||||
return fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
|
||||||
}
|
|
||||||
|
|
||||||
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if newUsername != storedRefreshAttributes.Username {
|
|
||||||
return fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
|
||||||
userDN, storedRefreshAttributes.Username, newUsername,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
|
||||||
if newSubject != storedRefreshAttributes.Subject {
|
|
||||||
return fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
|
||||||
}
|
|
||||||
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
|
||||||
err = validateFunc(userEntry, storedRefreshAttributes)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// we checked that the user still exists and their information is the same, so just return.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.SearchResult, error) {
|
|
||||||
search := p.refreshUserSearchRequest(userDN)
|
|
||||||
|
|
||||||
conn, err := p.dial(ctx)
|
conn, err := p.dial(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
|
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
|
||||||
@ -236,6 +186,65 @@ func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.Sea
|
|||||||
return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err)
|
return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
searchResult, err := p.performUserRefresh(conn, userDN)
|
||||||
|
if err != nil {
|
||||||
|
p.traceRefreshFailure(t, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// if any more or less than one entry, error.
|
||||||
|
// we don't need to worry about logging this because we know it's a dn.
|
||||||
|
if len(searchResult.Entries) != 1 {
|
||||||
|
return nil, fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
|
||||||
|
userDN, len(searchResult.Entries),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
userEntry := searchResult.Entries[0]
|
||||||
|
if len(userEntry.DN) == 0 {
|
||||||
|
return nil, fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
|
||||||
|
}
|
||||||
|
|
||||||
|
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if newUsername != storedRefreshAttributes.Username {
|
||||||
|
return nil, fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
|
||||||
|
userDN, storedRefreshAttributes.Username, newUsername,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
|
||||||
|
if newSubject != storedRefreshAttributes.Subject {
|
||||||
|
return nil, fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
|
||||||
|
}
|
||||||
|
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
|
||||||
|
err = validateFunc(userEntry, storedRefreshAttributes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have group search configured, search for groups to update the value.
|
||||||
|
if len(p.c.GroupSearch.Base) > 0 {
|
||||||
|
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sort.Strings(mappedGroupNames)
|
||||||
|
return mappedGroupNames, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) performUserRefresh(conn Conn, userDN string) (*ldap.SearchResult, error) {
|
||||||
|
search := p.refreshUserSearchRequest(userDN)
|
||||||
|
|
||||||
searchResult, err := conn.Search(search)
|
searchResult, err := conn.Search(search)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -455,7 +464,7 @@ func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, er
|
|||||||
groupAttributeName = distinguishedNameAttributeName
|
groupAttributeName = distinguishedNameAttributeName
|
||||||
}
|
}
|
||||||
|
|
||||||
var groups []string
|
groups := []string{}
|
||||||
entries:
|
entries:
|
||||||
for _, groupEntry := range searchResult.Entries {
|
for _, groupEntry := range searchResult.Entries {
|
||||||
if len(groupEntry.DN) == 0 {
|
if len(groupEntry.DN) == 0 {
|
||||||
|
@ -1100,6 +1100,18 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
Controls: nil, // don't need paging because we set the SizeLimit so small
|
Controls: nil, // don't need paging because we set the SizeLimit so small
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expectedGroupSearch := &ldap.SearchRequest{
|
||||||
|
BaseDN: testGroupSearchBase,
|
||||||
|
Scope: ldap.ScopeWholeSubtree,
|
||||||
|
DerefAliases: ldap.NeverDerefAliases,
|
||||||
|
SizeLimit: 0, // unlimited size because we will search with paging
|
||||||
|
TimeLimit: 90,
|
||||||
|
TypesOnly: false,
|
||||||
|
Filter: testGroupSearchFilterInterpolated,
|
||||||
|
Attributes: []string{testGroupSearchGroupNameAttribute},
|
||||||
|
Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us
|
||||||
|
}
|
||||||
|
|
||||||
happyPathUserSearchResult := &ldap.SearchResult{
|
happyPathUserSearchResult := &ldap.SearchResult{
|
||||||
Entries: []*ldap.Entry{
|
Entries: []*ldap.Entry{
|
||||||
{
|
{
|
||||||
@ -1146,6 +1158,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
setupMocks func(conn *mockldapconn.MockConn)
|
setupMocks func(conn *mockldapconn.MockConn)
|
||||||
dialError error
|
dialError error
|
||||||
wantErr string
|
wantErr string
|
||||||
|
wantGroups []string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path where searching the dn returns a single entry",
|
name: "happy path where searching the dn returns a single entry",
|
||||||
@ -1156,6 +1169,89 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
conn.EXPECT().Close().Times(1)
|
conn.EXPECT().Close().Times(1)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "happy path where group search returns groups",
|
||||||
|
providerConfig: &ProviderConfig{
|
||||||
|
Name: "some-provider-name",
|
||||||
|
Host: testHost,
|
||||||
|
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||||
|
ConnectionProtocol: TLS,
|
||||||
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
|
UserSearch: UserSearchConfig{
|
||||||
|
Base: testUserSearchBase,
|
||||||
|
UIDAttribute: testUserSearchUIDAttribute,
|
||||||
|
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||||
|
},
|
||||||
|
GroupSearch: GroupSearchConfig{
|
||||||
|
Base: testGroupSearchBase,
|
||||||
|
Filter: testGroupSearchFilter,
|
||||||
|
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||||
|
},
|
||||||
|
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||||
|
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||||
|
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||||
|
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||||
|
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||||
|
Entries: []*ldap.Entry{
|
||||||
|
{
|
||||||
|
DN: testGroupSearchResultDNValue1,
|
||||||
|
Attributes: []*ldap.EntryAttribute{
|
||||||
|
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
DN: testGroupSearchResultDNValue2,
|
||||||
|
Attributes: []*ldap.EntryAttribute{
|
||||||
|
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Referrals: []string{}, // note that we are not following referrals at this time
|
||||||
|
Controls: []ldap.Control{},
|
||||||
|
}, nil).Times(1)
|
||||||
|
conn.EXPECT().Close().Times(1)
|
||||||
|
},
|
||||||
|
wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "happy path where group search returns no groups",
|
||||||
|
providerConfig: &ProviderConfig{
|
||||||
|
Name: "some-provider-name",
|
||||||
|
Host: testHost,
|
||||||
|
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||||
|
ConnectionProtocol: TLS,
|
||||||
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
|
UserSearch: UserSearchConfig{
|
||||||
|
Base: testUserSearchBase,
|
||||||
|
UIDAttribute: testUserSearchUIDAttribute,
|
||||||
|
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||||
|
},
|
||||||
|
GroupSearch: GroupSearchConfig{
|
||||||
|
Base: testGroupSearchBase,
|
||||||
|
Filter: testGroupSearchFilter,
|
||||||
|
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||||
|
},
|
||||||
|
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||||
|
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||||
|
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||||
|
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||||
|
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
|
||||||
|
Entries: []*ldap.Entry{},
|
||||||
|
Referrals: []string{}, // note that we are not following referrals at this time
|
||||||
|
Controls: []ldap.Control{},
|
||||||
|
}, nil).Times(1)
|
||||||
|
conn.EXPECT().Close().Times(1)
|
||||||
|
},
|
||||||
|
wantGroups: []string{},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "error where dial fails",
|
name: "error where dial fails",
|
||||||
providerConfig: providerConfig,
|
providerConfig: providerConfig,
|
||||||
@ -1421,6 +1517,37 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login",
|
wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "group search returns an error",
|
||||||
|
providerConfig: &ProviderConfig{
|
||||||
|
Name: "some-provider-name",
|
||||||
|
Host: testHost,
|
||||||
|
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
|
||||||
|
ConnectionProtocol: TLS,
|
||||||
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
|
UserSearch: UserSearchConfig{
|
||||||
|
Base: testUserSearchBase,
|
||||||
|
UIDAttribute: testUserSearchUIDAttribute,
|
||||||
|
UsernameAttribute: testUserSearchUsernameAttribute,
|
||||||
|
},
|
||||||
|
GroupSearch: GroupSearchConfig{
|
||||||
|
Base: testGroupSearchBase,
|
||||||
|
Filter: testGroupSearchFilter,
|
||||||
|
GroupNameAttribute: testGroupSearchGroupNameAttribute,
|
||||||
|
},
|
||||||
|
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
|
||||||
|
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
setupMocks: func(conn *mockldapconn.MockConn) {
|
||||||
|
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
|
||||||
|
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
|
||||||
|
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1)
|
||||||
|
conn.EXPECT().Close().Times(1)
|
||||||
|
},
|
||||||
|
wantErr: "error searching for group memberships for user with DN \"some-upstream-user-dn\": some search error",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -1435,9 +1562,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
dialWasAttempted := false
|
dialWasAttempted := false
|
||||||
providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||||
dialWasAttempted = true
|
dialWasAttempted = true
|
||||||
require.Equal(t, providerConfig.Host, addr.Endpoint())
|
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
|
||||||
if tt.dialError != nil {
|
if tt.dialError != nil {
|
||||||
return nil, tt.dialError
|
return nil, tt.dialError
|
||||||
}
|
}
|
||||||
@ -1446,9 +1573,9 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
|
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
|
||||||
ldapProvider := New(*providerConfig)
|
ldapProvider := New(*tt.providerConfig)
|
||||||
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
|
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
|
||||||
err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
groups, err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
|
||||||
Username: testUserSearchResultUsernameAttributeValue,
|
Username: testUserSearchResultUsernameAttributeValue,
|
||||||
Subject: subject,
|
Subject: subject,
|
||||||
DN: testUserSearchResultDNValue,
|
DN: testUserSearchResultDNValue,
|
||||||
@ -1461,6 +1588,7 @@ func TestUpstreamRefresh(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
require.Equal(t, true, dialWasAttempted)
|
require.Equal(t, true, dialWasAttempted)
|
||||||
|
require.Equal(t, tt.wantGroups, groups)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -257,7 +257,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
|
|||||||
p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree
|
p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree
|
||||||
})),
|
})),
|
||||||
wantAuthResponse: &authenticators.Response{
|
wantAuthResponse: &authenticators.Response{
|
||||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil},
|
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
|
||||||
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
||||||
ExtraRefreshAttributes: map[string]string{},
|
ExtraRefreshAttributes: map[string]string{},
|
||||||
},
|
},
|
||||||
@ -328,7 +328,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
|
|||||||
p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema
|
p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema
|
||||||
})),
|
})),
|
||||||
wantAuthResponse: &authenticators.Response{
|
wantAuthResponse: &authenticators.Response{
|
||||||
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil},
|
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
|
||||||
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
|
||||||
ExtraRefreshAttributes: map[string]string{},
|
ExtraRefreshAttributes: map[string]string{},
|
||||||
},
|
},
|
||||||
|
@ -66,6 +66,9 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
|||||||
// Either revoke the user's session on the upstream provider, or manipulate the user's session
|
// Either revoke the user's session on the upstream provider, or manipulate the user's session
|
||||||
// data in such a way that it should cause the next upstream refresh attempt to fail.
|
// data in such a way that it should cause the next upstream refresh attempt to fail.
|
||||||
breakRefreshSessionData func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string)
|
breakRefreshSessionData func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string)
|
||||||
|
// Edit the refresh session data between the initial login and the refresh, which is expected to
|
||||||
|
// succeed.
|
||||||
|
editRefreshSessionDataWithoutBreaking func(t *testing.T, sessionData *psession.PinnipedSession)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "oidc with default username and groups claim settings",
|
name: "oidc with default username and groups claim settings",
|
||||||
@ -128,6 +131,14 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
|||||||
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Issuer+"?sub=") + ".+",
|
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Issuer+"?sub=") + ".+",
|
||||||
wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" },
|
wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" },
|
||||||
wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups,
|
wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups,
|
||||||
|
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession) {
|
||||||
|
// even if we update this group to the wrong thing, we expect that it will return to the correct
|
||||||
|
// value after we refresh.
|
||||||
|
// However if there are no expected groups then they will not update, so we should skip this.
|
||||||
|
if len(env.SupervisorUpstreamOIDC.ExpectedGroups) > 0 {
|
||||||
|
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "oidc without refresh token",
|
name: "oidc without refresh token",
|
||||||
@ -272,6 +283,11 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
|||||||
false,
|
false,
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
|
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession) {
|
||||||
|
// even if we update this group to the wrong thing, we expect that it will return to the correct
|
||||||
|
// value after we refresh.
|
||||||
|
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
|
||||||
|
},
|
||||||
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
|
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
|
||||||
customSessionData := pinnipedSession.Custom
|
customSessionData := pinnipedSession.Custom
|
||||||
require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType)
|
require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType)
|
||||||
@ -1272,6 +1288,7 @@ func TestSupervisorLogin_Browser(t *testing.T) {
|
|||||||
testSupervisorLogin(t,
|
testSupervisorLogin(t,
|
||||||
tt.createIDP,
|
tt.createIDP,
|
||||||
tt.requestAuthorization,
|
tt.requestAuthorization,
|
||||||
|
tt.editRefreshSessionDataWithoutBreaking,
|
||||||
tt.breakRefreshSessionData,
|
tt.breakRefreshSessionData,
|
||||||
tt.createTestUser,
|
tt.createTestUser,
|
||||||
tt.deleteTestUser,
|
tt.deleteTestUser,
|
||||||
@ -1405,8 +1422,9 @@ func requireEventuallySuccessfulActiveDirectoryIdentityProviderConditions(t *tes
|
|||||||
func testSupervisorLogin(
|
func testSupervisorLogin(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
createIDP func(t *testing.T) string,
|
createIDP func(t *testing.T) string,
|
||||||
requestAuthorization func(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, username, password string, httpClient *http.Client),
|
requestAuthorization func(t *testing.T, downstreamAuthorizeURL string, downstreamCallbackURL string, username string, password string, httpClient *http.Client),
|
||||||
breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, username string),
|
editRefreshSessionDataWithoutBreaking func(t *testing.T, pinnipedSession *psession.PinnipedSession),
|
||||||
|
breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName string, username string),
|
||||||
createTestUser func(t *testing.T) (string, string),
|
createTestUser func(t *testing.T) (string, string),
|
||||||
deleteTestUser func(t *testing.T, username string),
|
deleteTestUser func(t *testing.T, username string),
|
||||||
wantDownstreamIDTokenSubjectToMatch string,
|
wantDownstreamIDTokenSubjectToMatch string,
|
||||||
@ -1565,6 +1583,28 @@ func testSupervisorLogin(
|
|||||||
// token exchange on the original token
|
// token exchange on the original token
|
||||||
doTokenExchange(t, &downstreamOAuth2Config, tokenResponse, httpClient, discovery)
|
doTokenExchange(t, &downstreamOAuth2Config, tokenResponse, httpClient, discovery)
|
||||||
|
|
||||||
|
if editRefreshSessionDataWithoutBreaking != nil {
|
||||||
|
latestRefreshToken := tokenResponse.RefreshToken
|
||||||
|
signatureOfLatestRefreshToken := getFositeDataSignature(t, latestRefreshToken)
|
||||||
|
|
||||||
|
// First use the latest downstream refresh token to look up the corresponding session in the Supervisor's storage.
|
||||||
|
kubeClient := testlib.NewKubernetesClientset(t)
|
||||||
|
supervisorSecretsClient := kubeClient.CoreV1().Secrets(env.SupervisorNamespace)
|
||||||
|
oauthStore := oidc.NewKubeStorage(supervisorSecretsClient, oidc.DefaultOIDCTimeoutsConfiguration())
|
||||||
|
storedRefreshSession, err := oauthStore.GetRefreshTokenSession(ctx, signatureOfLatestRefreshToken, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Next mutate the part of the session that is used during upstream refresh.
|
||||||
|
pinnipedSession, ok := storedRefreshSession.GetSession().(*psession.PinnipedSession)
|
||||||
|
require.True(t, ok, "should have been able to cast session data to PinnipedSession")
|
||||||
|
|
||||||
|
editRefreshSessionDataWithoutBreaking(t, pinnipedSession)
|
||||||
|
|
||||||
|
// Then save the mutated Secret back to Kubernetes.
|
||||||
|
// There is no update function, so delete and create again at the same name.
|
||||||
|
require.NoError(t, oauthStore.DeleteRefreshTokenSession(ctx, signatureOfLatestRefreshToken))
|
||||||
|
require.NoError(t, oauthStore.CreateRefreshTokenSession(ctx, signatureOfLatestRefreshToken, storedRefreshSession))
|
||||||
|
}
|
||||||
// Use the refresh token to get new tokens
|
// Use the refresh token to get new tokens
|
||||||
refreshSource := downstreamOAuth2Config.TokenSource(oidcHTTPClientContext, &oauth2.Token{RefreshToken: tokenResponse.RefreshToken})
|
refreshSource := downstreamOAuth2Config.TokenSource(oidcHTTPClientContext, &oauth2.Token{RefreshToken: tokenResponse.RefreshToken})
|
||||||
refreshedTokenResponse, err := refreshSource.Token()
|
refreshedTokenResponse, err := refreshSource.Token()
|
||||||
|
Loading…
Reference in New Issue
Block a user