Always update groups even if it's nil

Also de-dup groups and various small formatting changes
This commit is contained in:
Margo Crawford 2022-02-14 14:01:21 -08:00
parent c28602f275
commit ca523b1f20
5 changed files with 24 additions and 26 deletions

View File

@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error)
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) (groups []string, err error)
}
type StoredRefreshAttributes struct {

View File

@ -313,9 +313,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
// If we got groups back, then replace the old value with the new value.
if groups != nil {
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
}
return nil
}

View File

@ -13,12 +13,12 @@ import (
"fmt"
"net"
"net/url"
"sort"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/authentication/user"
"k8s.io/utils/trace"
@ -230,15 +230,11 @@ func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes p
}
}
// If we have group search configured, search for groups to update the value.
if len(p.c.GroupSearch.Base) > 0 {
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
if err != nil {
return nil, err
}
return mappedGroupNames, nil
}
return nil, nil
}
func (p *Provider) performUserRefreshSearch(conn Conn, userDN string) (*ldap.SearchResult, error) {
@ -453,6 +449,11 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi
}
func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) {
// If we do not have group search configured, skip this search.
if len(p.c.GroupSearch.Base) == 0 {
return []string{}, nil
}
searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize)
if err != nil {
return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err)
@ -484,8 +485,9 @@ entries:
}
groups = append(groups, mappedGroupName)
}
sort.Strings(groups)
return groups, nil
// de-duplicate the list of groups by turning it into a set,
// then turn it back into a sorted list.
return sets.NewString(groups...).List(), nil
}
func (p *Provider) validateConfig() error {
@ -575,13 +577,10 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
return nil, err
}
var mappedGroupNames []string
if len(p.c.GroupSearch.Base) > 0 {
mappedGroupNames, err = p.searchGroupsForUserDN(conn, userEntry.DN)
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userEntry.DN)
if err != nil {
return nil, err
}
}
mappedRefreshAttributes := make(map[string]string)
for k := range p.c.RefreshAttributeChecks {

View File

@ -243,7 +243,7 @@ func TestEndUserAuthentication(t *testing.T) {
password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) {
p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) {
return "something-else", nil
return "something-else-" + entry.DN, nil
}}
}),
searchMocks: func(conn *mockldapconn.MockConn) {
@ -260,7 +260,7 @@ func TestEndUserAuthentication(t *testing.T) {
r.User = &user.DefaultInfo{
Name: testUserSearchResultUsernameAttributeValue,
UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)),
Groups: []string{"something-else", "something-else"},
Groups: []string{"something-else-some-upstream-group-dn1", "something-else-some-upstream-group-dn2"},
}
}),
},
@ -281,7 +281,7 @@ func TestEndUserAuthentication(t *testing.T) {
},
wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) {
info := r.User.(*user.DefaultInfo)
info.Groups = nil
info.Groups = []string{}
}),
},
{
@ -1213,6 +1213,7 @@ func TestUpstreamRefresh(t *testing.T) {
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantGroups: []string{},
},
{
name: "happy path where group search returns groups",

View File

@ -244,7 +244,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.Base = ""
})),
wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil},
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{},
},
@ -302,7 +302,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test
})),
wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}},
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames"}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{},
},