From ca523b1f20e1b1e7004155b7b6afd113d6b5ec8c Mon Sep 17 00:00:00 2001 From: Margo Crawford Date: Mon, 14 Feb 2022 14:01:21 -0800 Subject: [PATCH] Always update groups even if it's nil Also de-dup groups and various small formatting changes --- .../provider/dynamic_upstream_idp_provider.go | 2 +- internal/oidc/token/token_handler.go | 4 +-- internal/upstreamldap/upstreamldap.go | 33 +++++++++---------- internal/upstreamldap/upstreamldap_test.go | 7 ++-- test/integration/ldap_client_test.go | 4 +-- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 2a06a0ba..6f0ced3d 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface { authenticators.UserAuthenticator // PerformRefresh performs a refresh against the upstream LDAP identity provider - PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error) + PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) (groups []string, err error) } type StoredRefreshAttributes struct { diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index e434efc9..e7451f7b 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -313,9 +313,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) } // If we got groups back, then replace the old value with the new value. - if groups != nil { - session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups - } + session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups return nil } diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index c1f14fb7..6ea38b6b 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -13,12 +13,12 @@ import ( "fmt" "net" "net/url" - "sort" "strings" "time" "github.com/go-ldap/ldap/v3" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/utils/trace" @@ -230,15 +230,11 @@ func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes p } } - // If we have group search configured, search for groups to update the value. - if len(p.c.GroupSearch.Base) > 0 { - mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN) - if err != nil { - return nil, err - } - return mappedGroupNames, nil + mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN) + if err != nil { + return nil, err } - return nil, nil + return mappedGroupNames, nil } func (p *Provider) performUserRefreshSearch(conn Conn, userDN string) (*ldap.SearchResult, error) { @@ -453,6 +449,11 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi } func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) { + // If we do not have group search configured, skip this search. + if len(p.c.GroupSearch.Base) == 0 { + return []string{}, nil + } + searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize) if err != nil { return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err) @@ -484,8 +485,9 @@ entries: } groups = append(groups, mappedGroupName) } - sort.Strings(groups) - return groups, nil + // de-duplicate the list of groups by turning it into a set, + // then turn it back into a sorted list. + return sets.NewString(groups...).List(), nil } func (p *Provider) validateConfig() error { @@ -575,12 +577,9 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c return nil, err } - var mappedGroupNames []string - if len(p.c.GroupSearch.Base) > 0 { - mappedGroupNames, err = p.searchGroupsForUserDN(conn, userEntry.DN) - if err != nil { - return nil, err - } + mappedGroupNames, err := p.searchGroupsForUserDN(conn, userEntry.DN) + if err != nil { + return nil, err } mappedRefreshAttributes := make(map[string]string) diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 5a028e51..8b899568 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -243,7 +243,7 @@ func TestEndUserAuthentication(t *testing.T) { password: testUpstreamPassword, providerConfig: providerConfig(func(p *ProviderConfig) { p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) { - return "something-else", nil + return "something-else-" + entry.DN, nil }} }), searchMocks: func(conn *mockldapconn.MockConn) { @@ -260,7 +260,7 @@ func TestEndUserAuthentication(t *testing.T) { r.User = &user.DefaultInfo{ Name: testUserSearchResultUsernameAttributeValue, UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)), - Groups: []string{"something-else", "something-else"}, + Groups: []string{"something-else-some-upstream-group-dn1", "something-else-some-upstream-group-dn2"}, } }), }, @@ -281,7 +281,7 @@ func TestEndUserAuthentication(t *testing.T) { }, wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { info := r.User.(*user.DefaultInfo) - info.Groups = nil + info.Groups = []string{} }), }, { @@ -1213,6 +1213,7 @@ func TestUpstreamRefresh(t *testing.T) { conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) conn.EXPECT().Close().Times(1) }, + wantGroups: []string{}, }, { name: "happy path where group search returns groups", diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index 923acb4a..604e5e2e 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -244,7 +244,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.Base = "" })), wantAuthResponse: &authenticators.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}}, DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", ExtraRefreshAttributes: map[string]string{}, }, @@ -302,7 +302,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test })), wantAuthResponse: &authenticators.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames"}}, DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", ExtraRefreshAttributes: map[string]string{}, },