Always update groups even if it's nil

Also de-dup groups and various small formatting changes
This commit is contained in:
Margo Crawford 2022-02-14 14:01:21 -08:00
parent c28602f275
commit ca523b1f20
5 changed files with 24 additions and 26 deletions

View File

@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
authenticators.UserAuthenticator authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider // PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) ([]string, error) PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) (groups []string, err error)
} }
type StoredRefreshAttributes struct { type StoredRefreshAttributes struct {

View File

@ -313,9 +313,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
// If we got groups back, then replace the old value with the new value. // If we got groups back, then replace the old value with the new value.
if groups != nil { session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
}
return nil return nil
} }

View File

@ -13,12 +13,12 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"sort"
"strings" "strings"
"time" "time"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/utils/trace" "k8s.io/utils/trace"
@ -230,15 +230,11 @@ func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes p
} }
} }
// If we have group search configured, search for groups to update the value. mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
if len(p.c.GroupSearch.Base) > 0 { if err != nil {
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN) return nil, err
if err != nil {
return nil, err
}
return mappedGroupNames, nil
} }
return nil, nil return mappedGroupNames, nil
} }
func (p *Provider) performUserRefreshSearch(conn Conn, userDN string) (*ldap.SearchResult, error) { func (p *Provider) performUserRefreshSearch(conn Conn, userDN string) (*ldap.SearchResult, error) {
@ -453,6 +449,11 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi
} }
func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) { func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) {
// If we do not have group search configured, skip this search.
if len(p.c.GroupSearch.Base) == 0 {
return []string{}, nil
}
searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize) searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize)
if err != nil { if err != nil {
return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err) return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err)
@ -484,8 +485,9 @@ entries:
} }
groups = append(groups, mappedGroupName) groups = append(groups, mappedGroupName)
} }
sort.Strings(groups) // de-duplicate the list of groups by turning it into a set,
return groups, nil // then turn it back into a sorted list.
return sets.NewString(groups...).List(), nil
} }
func (p *Provider) validateConfig() error { func (p *Provider) validateConfig() error {
@ -575,12 +577,9 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
return nil, err return nil, err
} }
var mappedGroupNames []string mappedGroupNames, err := p.searchGroupsForUserDN(conn, userEntry.DN)
if len(p.c.GroupSearch.Base) > 0 { if err != nil {
mappedGroupNames, err = p.searchGroupsForUserDN(conn, userEntry.DN) return nil, err
if err != nil {
return nil, err
}
} }
mappedRefreshAttributes := make(map[string]string) mappedRefreshAttributes := make(map[string]string)

View File

@ -243,7 +243,7 @@ func TestEndUserAuthentication(t *testing.T) {
password: testUpstreamPassword, password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) { p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) {
return "something-else", nil return "something-else-" + entry.DN, nil
}} }}
}), }),
searchMocks: func(conn *mockldapconn.MockConn) { searchMocks: func(conn *mockldapconn.MockConn) {
@ -260,7 +260,7 @@ func TestEndUserAuthentication(t *testing.T) {
r.User = &user.DefaultInfo{ r.User = &user.DefaultInfo{
Name: testUserSearchResultUsernameAttributeValue, Name: testUserSearchResultUsernameAttributeValue,
UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)), UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)),
Groups: []string{"something-else", "something-else"}, Groups: []string{"something-else-some-upstream-group-dn1", "something-else-some-upstream-group-dn2"},
} }
}), }),
}, },
@ -281,7 +281,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, },
wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) {
info := r.User.(*user.DefaultInfo) info := r.User.(*user.DefaultInfo)
info.Groups = nil info.Groups = []string{}
}), }),
}, },
{ {
@ -1213,6 +1213,7 @@ func TestUpstreamRefresh(t *testing.T) {
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantGroups: []string{},
}, },
{ {
name: "happy path where group search returns groups", name: "happy path where group search returns groups",

View File

@ -244,7 +244,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.Base = "" p.GroupSearch.Base = ""
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },
@ -302,7 +302,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames"}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },