Introduce upstreamldap.New to prevent changes to the underlying config

Makes it easier to support using the same upstreamldap.Provider from
multiple goroutines safely.
This commit is contained in:
Ryan Richard 2021-04-15 10:25:35 -07:00
parent 5c28d36c9b
commit e6e6497022
5 changed files with 296 additions and 194 deletions

View File

@ -107,10 +107,10 @@ func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error {
func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI { func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI {
spec := upstream.Spec spec := upstream.Spec
result := &upstreamldap.Provider{ config := &upstreamldap.ProviderConfig{
Name: upstream.Name, Name: upstream.Name,
Host: spec.Host, Host: spec.Host,
UserSearch: &upstreamldap.UserSearch{ UserSearch: upstreamldap.UserSearchConfig{
Base: spec.UserSearch.Base, Base: spec.UserSearch.Base,
Filter: spec.UserSearch.Filter, Filter: spec.UserSearch.Filter,
UsernameAttribute: spec.UserSearch.Attributes.Username, UsernameAttribute: spec.UserSearch.Attributes.Username,
@ -119,17 +119,17 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *
Dialer: c.ldapDialer, Dialer: c.ldapDialer,
} }
conditions := []*v1alpha1.Condition{ conditions := []*v1alpha1.Condition{
c.validateSecret(upstream, result), c.validateSecret(upstream, config),
c.validateTLSConfig(upstream, result), c.validateTLSConfig(upstream, config),
} }
hadErrorCondition := c.updateStatus(ctx, upstream, conditions) hadErrorCondition := c.updateStatus(ctx, upstream, conditions)
if hadErrorCondition { if hadErrorCondition {
return nil return nil
} }
return result return upstreamldap.New(*config)
} }
func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.Provider) *v1alpha1.Condition { func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition {
tlsSpec := upstream.Spec.TLS tlsSpec := upstream.Spec.TLS
if tlsSpec == nil { if tlsSpec == nil {
return c.validTLSCondition(noTLSConfigurationMessage) return c.validTLSCondition(noTLSConfigurationMessage)
@ -171,7 +171,7 @@ func (c *ldapWatcherController) invalidTLSCondition(message string) *v1alpha1.Co
} }
} }
func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.Provider) *v1alpha1.Condition { func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition {
secretName := upstream.Spec.Bind.SecretName secretName := upstream.Spec.Bind.SecretName
secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName) secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName)

View File

@ -194,13 +194,13 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
return deepCopy return deepCopy
} }
providerForValidUpstream := &upstreamldap.Provider{ providerConfigForValidUpstream := &upstreamldap.ProviderConfig{
Name: testName, Name: testName,
Host: testHost, Host: testHost,
CABundle: testCABundle, CABundle: testCABundle,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
UserSearch: &upstreamldap.UserSearch{ UserSearch: upstreamldap.UserSearchConfig{
Base: testUserSearchBase, Base: testUserSearchBase,
Filter: testUserSearchFilter, Filter: testUserSearchFilter,
UsernameAttribute: testUsernameAttrName, UsernameAttribute: testUsernameAttrName,
@ -215,7 +215,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
inputSecrets []runtime.Object inputSecrets []runtime.Object
ldapDialer upstreamldap.LDAPDialer ldapDialer upstreamldap.LDAPDialer
wantErr string wantErr string
wantResultingCache []*upstreamldap.Provider wantResultingCache []*upstreamldap.ProviderConfig
wantResultingUpstreams []v1alpha1.LDAPIdentityProvider wantResultingUpstreams []v1alpha1.LDAPIdentityProvider
}{ }{
{ {
@ -230,7 +230,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Type: corev1.SecretTypeBasicAuth, Type: corev1.SecretTypeBasicAuth,
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantResultingCache: []*upstreamldap.Provider{providerForValidUpstream}, wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -262,7 +262,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
inputUpstreams: []runtime.Object{validUpstream}, inputUpstreams: []runtime.Object{validUpstream},
inputSecrets: []runtime.Object{}, inputSecrets: []runtime.Object{},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{}, wantResultingCache: []*upstreamldap.ProviderConfig{},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -298,7 +298,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{}, wantResultingCache: []*upstreamldap.ProviderConfig{},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -333,7 +333,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Type: corev1.SecretTypeBasicAuth, Type: corev1.SecretTypeBasicAuth,
}}, }},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{}, wantResultingCache: []*upstreamldap.ProviderConfig{},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -371,7 +371,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{}, wantResultingCache: []*upstreamldap.ProviderConfig{},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -409,7 +409,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{}, wantResultingCache: []*upstreamldap.ProviderConfig{},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{ Status: v1alpha1.LDAPIdentityProviderStatus{
@ -446,14 +446,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Type: corev1.SecretTypeBasicAuth, Type: corev1.SecretTypeBasicAuth,
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantResultingCache: []*upstreamldap.Provider{ wantResultingCache: []*upstreamldap.ProviderConfig{
{ {
Name: testName, Name: testName,
Host: testHost, Host: testHost,
CABundle: nil, CABundle: nil,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
UserSearch: &upstreamldap.UserSearch{ UserSearch: upstreamldap.UserSearchConfig{
Base: testUserSearchBase, Base: testUserSearchBase,
Filter: testUserSearchFilter, Filter: testUserSearchFilter,
UsernameAttribute: testUsernameAttrName, UsernameAttribute: testUsernameAttrName,
@ -498,14 +498,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Type: corev1.SecretTypeBasicAuth, Type: corev1.SecretTypeBasicAuth,
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantResultingCache: []*upstreamldap.Provider{ wantResultingCache: []*upstreamldap.ProviderConfig{
{ {
Name: testName, Name: testName,
Host: testHost, Host: testHost,
CABundle: nil, CABundle: nil,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
UserSearch: &upstreamldap.UserSearch{ UserSearch: upstreamldap.UserSearchConfig{
Base: testUserSearchBase, Base: testUserSearchBase,
Filter: testUserSearchFilter, Filter: testUserSearchFilter,
UsernameAttribute: testUsernameAttrName, UsernameAttribute: testUsernameAttrName,
@ -553,7 +553,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantErr: controllerlib.ErrSyntheticRequeue.Error(),
wantResultingCache: []*upstreamldap.Provider{providerForValidUpstream}, wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{
{ {
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: "other-upstream", Generation: 42}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: "other-upstream", Generation: 42},
@ -616,7 +616,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ cache.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{
&upstreamldap.Provider{Name: "initial-entry"}, upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}),
}) })
controller := NewLDAPUpstreamWatcherController( controller := NewLDAPUpstreamWatcherController(
@ -647,7 +647,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList { for i := range actualIDPList {
actualIDP := actualIDPList[i].(*upstreamldap.Provider) actualIDP := actualIDPList[i].(*upstreamldap.Provider)
require.Equal(t, tt.wantResultingCache[i], actualIDP) require.Equal(t, *tt.wantResultingCache[i], actualIDP.GetConfig())
} }
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{}) actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{})

View File

@ -50,9 +50,10 @@ func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, err
return f(ctx, hostAndPort) return f(ctx, hostAndPort)
} }
// Provider includes all of the settings for connection and searching for users and groups in // ProviderConfig includes all of the settings for connection and searching for users and groups in
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins. // the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
type Provider struct { // The nested structs are not pointer fields to enable deep copy on function params and return values.
type ProviderConfig struct {
// Name is the unique name of this upstream LDAP IDP. // Name is the unique name of this upstream LDAP IDP.
Name string Name string
@ -70,14 +71,14 @@ type Provider struct {
BindPassword string BindPassword string
// UserSearch contains information about how to search for users in the upstream LDAP IDP. // UserSearch contains information about how to search for users in the upstream LDAP IDP.
UserSearch *UserSearch UserSearch UserSearchConfig
// Dialer exists to enable testing. When nil, will use a default appropriate for production use. // Dialer exists to enable testing. When nil, will use a default appropriate for production use.
Dialer LDAPDialer Dialer LDAPDialer
} }
// UserSearch contains information about how to search for users in the upstream LDAP IDP. // UserSearchConfig contains information about how to search for users in the upstream LDAP IDP.
type UserSearch struct { type UserSearchConfig struct {
// Base is the base DN to use for the user search in the upstream LDAP IDP. // Base is the base DN to use for the user search in the upstream LDAP IDP.
Base string Base string
@ -93,13 +94,28 @@ type UserSearch struct {
UIDAttribute string UIDAttribute string
} }
type Provider struct {
c ProviderConfig
}
// Create a Provider. The config is not a pointer to ensure that a copy of the config is created,
// making the resulting Provider use an effectively read-only configuration.
func New(config ProviderConfig) *Provider {
return &Provider{c: config}
}
// A reader for the config. Returns a copy of the config to keep the underlying config read-only.
func (p *Provider) GetConfig() ProviderConfig {
return p.c
}
func (p *Provider) dial(ctx context.Context) (Conn, error) { func (p *Provider) dial(ctx context.Context) (Conn, error) {
hostAndPort, err := hostAndPortWithDefaultPort(p.Host, ldap.DefaultLdapsPort) hostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
if err != nil { if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err) return nil, ldap.NewError(ldap.ErrorNetwork, err)
} }
if p.Dialer != nil { if p.c.Dialer != nil {
return p.Dialer.Dial(ctx, hostAndPort) return p.c.Dialer.Dial(ctx, hostAndPort)
} }
return p.dialTLS(ctx, hostAndPort) return p.dialTLS(ctx, hostAndPort)
} }
@ -109,8 +125,8 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) {
// so we implement it ourselves, heavily inspired by ldap.DialURL. // so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
rootCAs := x509.NewCertPool() rootCAs := x509.NewCertPool()
if p.CABundle != nil { if p.c.CABundle != nil {
if !rootCAs.AppendCertsFromPEM(p.CABundle) { if !rootCAs.AppendCertsFromPEM(p.c.CABundle) {
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle")) return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle"))
} }
} }
@ -154,14 +170,14 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string,
// A name for this upstream provider. // A name for this upstream provider.
func (p *Provider) GetName() string { func (p *Provider) GetName() string {
return p.Name return p.c.Name
} }
// Return a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234". // Return a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234".
// This URL is not used for connecting to the provider, but rather is used for creating a globally unique user // This URL is not used for connecting to the provider, but rather is used for creating a globally unique user
// identifier by being combined with the user's UID, since user UIDs are only unique within one provider. // identifier by being combined with the user's UID, since user UIDs are only unique within one provider.
func (p *Provider) GetURL() string { func (p *Provider) GetURL() string {
return fmt.Sprintf("%s://%s", ldapsScheme, p.Host) return fmt.Sprintf("%s://%s", ldapsScheme, p.c.Host)
} }
// TestConnection provides a method for testing the connection and bind settings. It performs a dial and bind // TestConnection provides a method for testing the connection and bind settings. It performs a dial and bind
@ -183,7 +199,7 @@ func (p *Provider) TestAuthenticateUser(ctx context.Context, testUsername string
// Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator. // Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator.
func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) { func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) {
if p.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.UserSearch.Filter) == 0 { if p.c.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.c.UserSearch.Filter) == 0 {
// LDAP search filters do not allow searching by DN. // LDAP search filters do not allow searching by DN.
return nil, false, fmt.Errorf(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`) return nil, false, fmt.Errorf(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`)
} }
@ -195,13 +211,13 @@ func (p *Provider) AuthenticateUser(ctx context.Context, username, password stri
conn, err := p.dial(ctx) conn, err := p.dial(ctx)
if err != nil { if err != nil {
return nil, false, fmt.Errorf(`error dialing host "%s": %w`, p.Host, err) return nil, false, fmt.Errorf(`error dialing host "%s": %w`, p.c.Host, err)
} }
defer conn.Close() defer conn.Close()
err = conn.Bind(p.BindUsername, p.BindPassword) err = conn.Bind(p.c.BindUsername, p.c.BindPassword)
if err != nil { if err != nil {
return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.BindUsername, err) return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.c.BindUsername, err)
} }
mappedUsername, mappedUID, err := p.searchAndBindUser(conn, username, password) mappedUsername, mappedUID, err := p.searchAndBindUser(conn, username, password)
@ -243,12 +259,12 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, password string
return "", "", fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username) return "", "", fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username)
} }
mappedUsername, err := p.getSearchResultAttributeValue(p.UserSearch.UsernameAttribute, userEntry, username) mappedUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, username)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
mappedUID, err := p.getSearchResultAttributeValue(p.UserSearch.UIDAttribute, userEntry, username) mappedUID, err := p.getSearchResultAttributeValue(p.c.UserSearch.UIDAttribute, userEntry, username)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
@ -270,7 +286,7 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, password string
func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest { func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest {
// See https://ldap.com/the-ldap-search-operation for general documentation of LDAP search options. // See https://ldap.com/the-ldap-search-operation for general documentation of LDAP search options.
return &ldap.SearchRequest{ return &ldap.SearchRequest{
BaseDN: p.UserSearch.Base, BaseDN: p.c.UserSearch.Base,
Scope: ldap.ScopeWholeSubtree, Scope: ldap.ScopeWholeSubtree,
DerefAliases: ldap.DerefAlways, // TODO what's the best value here? DerefAliases: ldap.DerefAlways, // TODO what's the best value here?
SizeLimit: 2, SizeLimit: 2,
@ -284,21 +300,21 @@ func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest {
func (p *Provider) userSearchRequestedAttributes() []string { func (p *Provider) userSearchRequestedAttributes() []string {
attributes := []string{} attributes := []string{}
if p.UserSearch.UsernameAttribute != distinguishedNameAttributeName { if p.c.UserSearch.UsernameAttribute != distinguishedNameAttributeName {
attributes = append(attributes, p.UserSearch.UsernameAttribute) attributes = append(attributes, p.c.UserSearch.UsernameAttribute)
} }
if p.UserSearch.UIDAttribute != distinguishedNameAttributeName { if p.c.UserSearch.UIDAttribute != distinguishedNameAttributeName {
attributes = append(attributes, p.UserSearch.UIDAttribute) attributes = append(attributes, p.c.UserSearch.UIDAttribute)
} }
return attributes return attributes
} }
func (p *Provider) userSearchFilter(username string) string { func (p *Provider) userSearchFilter(username string) string {
safeUsername := p.escapeUsernameForSearchFilter(username) safeUsername := p.escapeUsernameForSearchFilter(username)
if len(p.UserSearch.Filter) == 0 { if len(p.c.UserSearch.Filter) == 0 {
return fmt.Sprintf("(%s=%s)", p.UserSearch.UsernameAttribute, safeUsername) return fmt.Sprintf("(%s=%s)", p.c.UserSearch.UsernameAttribute, safeUsername)
} }
filter := strings.ReplaceAll(p.UserSearch.Filter, userSearchFilterInterpolationLocationMarker, safeUsername) filter := strings.ReplaceAll(p.c.UserSearch.Filter, userSearchFilterInterpolationLocationMarker, safeUsername)
if strings.HasPrefix(filter, "(") && strings.HasSuffix(filter, ")") { if strings.HasPrefix(filter, "(") && strings.HasSuffix(filter, ")") {
return filter return filter
} }

View File

@ -43,14 +43,14 @@ var (
) )
func TestAuthenticateUser(t *testing.T) { func TestAuthenticateUser(t *testing.T) {
provider := func(editFunc func(p *Provider)) *Provider { providerConfig := func(editFunc func(p *ProviderConfig)) *ProviderConfig {
provider := &Provider{ config := &ProviderConfig{
Name: "some-provider-name", Name: "some-provider-name",
Host: testHost, Host: testHost,
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
UserSearch: &UserSearch{ UserSearch: UserSearchConfig{
Base: testUserSearchBase, Base: testUserSearchBase,
Filter: testUserSearchFilter, Filter: testUserSearchFilter,
UsernameAttribute: testUserSearchUsernameAttribute, UsernameAttribute: testUserSearchUsernameAttribute,
@ -58,9 +58,9 @@ func TestAuthenticateUser(t *testing.T) {
}, },
} }
if editFunc != nil { if editFunc != nil {
editFunc(provider) editFunc(config)
} }
return provider return config
} }
expectedSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest { expectedSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest {
@ -85,7 +85,7 @@ func TestAuthenticateUser(t *testing.T) {
name string name string
username string username string
password string password string
provider *Provider providerConfig *ProviderConfig
setupMocks func(conn *mockldapconn.MockConn) setupMocks func(conn *mockldapconn.MockConn)
dialError error dialError error
wantError string wantError string
@ -94,10 +94,10 @@ func TestAuthenticateUser(t *testing.T) {
wantUnauthenticated bool wantUnauthenticated bool
}{ }{
{ {
name: "happy path", name: "happy path",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -128,7 +128,7 @@ func TestAuthenticateUser(t *testing.T) {
name: "when the user search filter is already wrapped by parenthesis then it is not wrapped again", name: "when the user search filter is already wrapped by parenthesis then it is not wrapped again",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(func(p *Provider) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.UserSearch.Filter = "(" + testUserSearchFilter + ")" p.UserSearch.Filter = "(" + testUserSearchFilter + ")"
}), }),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
@ -159,7 +159,7 @@ func TestAuthenticateUser(t *testing.T) {
name: "when the UsernameAttribute is dn and there is a user search filter provided", name: "when the UsernameAttribute is dn and there is a user search filter provided",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(func(p *Provider) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.UserSearch.UsernameAttribute = "dn" p.UserSearch.UsernameAttribute = "dn"
}), }),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
@ -191,7 +191,7 @@ func TestAuthenticateUser(t *testing.T) {
name: "when the UIDAttribute is dn", name: "when the UIDAttribute is dn",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(func(p *Provider) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.UserSearch.UIDAttribute = "dn" p.UserSearch.UIDAttribute = "dn"
}), }),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
@ -223,7 +223,7 @@ func TestAuthenticateUser(t *testing.T) {
name: "when Filter is blank it derives a search filter from the UsernameAttribute", name: "when Filter is blank it derives a search filter from the UsernameAttribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(func(p *Provider) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
}), }),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
@ -253,10 +253,10 @@ func TestAuthenticateUser(t *testing.T) {
}, },
}, },
{ {
name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter", name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter",
username: `a&b|c(d)e\f*g`, username: `a&b|c(d)e\f*g`,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) {
@ -284,18 +284,18 @@ func TestAuthenticateUser(t *testing.T) {
}, },
}, },
{ {
name: "when dial fails", name: "when dial fails",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
dialError: errors.New("some dial error"), dialError: errors.New("some dial error"),
wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost),
}, },
{ {
name: "when the UsernameAttribute is dn and there is not a user search filter provided", name: "when the UsernameAttribute is dn and there is not a user search filter provided",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(func(p *Provider) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.UserSearch.UsernameAttribute = "dn" p.UserSearch.UsernameAttribute = "dn"
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
}), }),
@ -303,10 +303,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`,
}, },
{ {
name: "when binding as the bind user returns an error", name: "when binding as the bind user returns an error",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
@ -314,10 +314,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`error binding as "%s" before user search: some bind error`, testBindUsername), wantError: fmt.Sprintf(`error binding as "%s" before user search: some bind error`, testBindUsername),
}, },
{ {
name: "when searching for the user returns an error", name: "when searching for the user returns an error",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(nil, errors.New("some search error")).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(nil, errors.New("some search error")).Times(1)
@ -326,10 +326,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`error searching for user "%s": some search error`, testUpstreamUsername), wantError: fmt.Sprintf(`error searching for user "%s": some search error`, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns no results", name: "when searching for the user returns no results",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -340,10 +340,10 @@ func TestAuthenticateUser(t *testing.T) {
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when searching for the user returns multiple results", name: "when searching for the user returns multiple results",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -357,10 +357,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`searching for user "%s" resulted in 2 search results, but expected 1 result`, testUpstreamUsername), wantError: fmt.Sprintf(`searching for user "%s" resulted in 2 search results, but expected 1 result`, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user without a DN", name: "when searching for the user returns a user without a DN",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -373,10 +373,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`searching for user "%s" resulted in search result without DN`, testUpstreamUsername), wantError: fmt.Sprintf(`searching for user "%s" resulted in search result without DN`, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user without an expected username attribute", name: "when searching for the user returns a user without an expected username attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -394,10 +394,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with too many values for the expected username attribute", name: "when searching for the user returns a user with too many values for the expected username attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -419,10 +419,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with an empty value for the expected username attribute", name: "when searching for the user returns a user with an empty value for the expected username attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -441,10 +441,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUsernameAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user without an expected UID attribute", name: "when searching for the user returns a user without an expected UID attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -462,10 +462,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with too many values for the expected UID attribute", name: "when searching for the user returns a user with too many values for the expected UID attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -487,10 +487,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with an empty value for the expected UID attribute", name: "when searching for the user returns a user with an empty value for the expected UID attribute",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -509,10 +509,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when binding as the found user returns an error", name: "when binding as the found user returns an error",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -532,10 +532,10 @@ func TestAuthenticateUser(t *testing.T) {
wantError: fmt.Sprintf(`error binding for user "%s" using provided password against DN "%s": some bind error`, testUpstreamUsername, testSearchResultDNValue), wantError: fmt.Sprintf(`error binding for user "%s" using provided password against DN "%s": some bind error`, testUpstreamUsername, testSearchResultDNValue),
}, },
{ {
name: "when binding as the found user returns a specific invalid credentials error", name: "when binding as the found user returns a specific invalid credentials error",
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) { setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{
@ -558,7 +558,7 @@ func TestAuthenticateUser(t *testing.T) {
name: "when no username is specified", name: "when no username is specified",
username: "", username: "",
password: testUpstreamPassword, password: testUpstreamPassword,
provider: provider(nil), providerConfig: providerConfig(nil),
wantToSkipDial: true, wantToSkipDial: true,
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
@ -576,16 +576,17 @@ func TestAuthenticateUser(t *testing.T) {
} }
dialWasAttempted := false dialWasAttempted := false
tt.provider.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
dialWasAttempted = true dialWasAttempted = true
require.Equal(t, tt.provider.Host, hostAndPort) require.Equal(t, tt.providerConfig.Host, hostAndPort)
if tt.dialError != nil { if tt.dialError != nil {
return nil, tt.dialError return nil, tt.dialError
} }
return conn, nil return conn, nil
}) })
authResponse, authenticated, err := tt.provider.AuthenticateUser(context.Background(), tt.username, tt.password) provider := New(*tt.providerConfig)
authResponse, authenticated, err := provider.AuthenticateUser(context.Background(), tt.username, tt.password)
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
@ -607,9 +608,37 @@ func TestAuthenticateUser(t *testing.T) {
} }
} }
func TestGetConfig(t *testing.T) {
c := ProviderConfig{
Name: "original-provider-name",
Host: testHost,
CABundle: []byte("some-ca-bundle"),
BindUsername: testBindUsername,
BindPassword: testBindPassword,
UserSearch: UserSearchConfig{
Base: testUserSearchBase,
Filter: testUserSearchFilter,
UsernameAttribute: testUserSearchUsernameAttribute,
UIDAttribute: testUserSearchUIDAttribute,
},
}
p := New(c)
require.Equal(t, c, p.c)
require.Equal(t, c, p.GetConfig())
// The original config can be changed without impacting the provider, since the provider made a copy of the config.
c.Name = "changed-name"
require.Equal(t, "original-provider-name", p.c.Name)
// The return value of GetConfig can be modified without impacting the provider, since it is a copy of the config.
returnedConfig := p.GetConfig()
returnedConfig.Name = "changed-name"
require.Equal(t, "original-provider-name", p.c.Name)
}
func TestGetURL(t *testing.T) { func TestGetURL(t *testing.T) {
require.Equal(t, "ldaps://ldap.example.com:1234", (&Provider{Host: "ldap.example.com:1234"}).GetURL()) require.Equal(t, "ldaps://ldap.example.com:1234", New(ProviderConfig{Host: "ldap.example.com:1234"}).GetURL())
require.Equal(t, "ldaps://ldap.example.com", (&Provider{Host: "ldap.example.com"}).GetURL()) require.Equal(t, "ldaps://ldap.example.com", New(ProviderConfig{Host: "ldap.example.com"}).GetURL())
} }
// Testing of host parsing, TLS negotiation, and CA bundle, etc. for the production code's dialer. // Testing of host parsing, TLS negotiation, and CA bundle, etc. for the production code's dialer.
@ -673,11 +702,11 @@ func TestRealTLSDialing(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
provider := &Provider{ provider := New(ProviderConfig{
Host: test.host, Host: test.host,
CABundle: test.caBundle, CABundle: test.caBundle,
Dialer: nil, // this test is for the default (production) dialer Dialer: nil, // this test is for the default (production) dialer
} })
conn, err := provider.dial(test.context) conn, err := provider.dial(test.context)
if conn != nil { if conn != nil {
defer conn.Close() defer conn.Close()

View File

@ -43,24 +43,12 @@ func TestLDAPSearch(t *testing.T) {
// Expose the the test LDAP server's TLS port on the localhost. // Expose the the test LDAP server's TLS port on the localhost.
startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace) startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace)
provider := func(editFunc func(p *upstreamldap.Provider)) *upstreamldap.Provider { providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig {
provider := &upstreamldap.Provider{ providerConfig := defaultProviderConfig(env, ldapHostPort)
Name: "test-ldap-provider",
Host: "127.0.0.1:" + ldapHostPort,
CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle),
BindUsername: "cn=admin,dc=pinniped,dc=dev",
BindPassword: "password",
UserSearch: &upstreamldap.UserSearch{
Base: "ou=users,dc=pinniped,dc=dev",
Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case
UsernameAttribute: "cn",
UIDAttribute: "uidNumber",
},
}
if editFunc != nil { if editFunc != nil {
editFunc(provider) editFunc(providerConfig)
} }
return provider return providerConfig
} }
pinnyPassword := env.SupervisorUpstreamLDAP.TestUserPassword pinnyPassword := env.SupervisorUpstreamLDAP.TestUserPassword
@ -78,7 +66,7 @@ func TestLDAPSearch(t *testing.T) {
name: "happy path", name: "happy path",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, },
@ -87,7 +75,7 @@ func TestLDAPSearch(t *testing.T) {
name: "using a different user search base", name: "using a different user search base",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "dc=pinniped,dc=dev" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "dc=pinniped,dc=dev" })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, },
@ -96,7 +84,7 @@ func TestLDAPSearch(t *testing.T) {
name: "when the user search filter is already wrapped by parenthesis", name: "when the user search filter is already wrapped by parenthesis",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Filter = "(cn={})" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "(cn={})" })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, },
@ -105,10 +93,10 @@ func TestLDAPSearch(t *testing.T) {
name: "when the UsernameAttribute is dn and a user search filter is provided", name: "when the UsernameAttribute is dn and a user search filter is provided",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.UsernameAttribute = "dn" p.UserSearch.UsernameAttribute = "dn"
p.UserSearch.Filter = "cn={}" p.UserSearch.Filter = "cn={}"
}), })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "cn=pinny,ou=users,dc=pinniped,dc=dev", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "cn=pinny,ou=users,dc=pinniped,dc=dev", UID: "1000", Groups: []string{}},
}, },
@ -117,9 +105,9 @@ func TestLDAPSearch(t *testing.T) {
name: "when the user search filter allows for different ways of logging in and the first one is used", name: "when the user search filter allows for different ways of logging in and the first one is used",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.Filter = "(|(cn={})(mail={}))" p.UserSearch.Filter = "(|(cn={})(mail={}))"
}), })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, },
@ -128,9 +116,9 @@ func TestLDAPSearch(t *testing.T) {
name: "when the user search filter allows for different ways of logging in and the second one is used", name: "when the user search filter allows for different ways of logging in and the second one is used",
username: "pinny.ldap@example.com", username: "pinny.ldap@example.com",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.Filter = "(|(cn={})(mail={}))" p.UserSearch.Filter = "(|(cn={})(mail={}))"
}), })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, },
@ -139,7 +127,7 @@ func TestLDAPSearch(t *testing.T) {
name: "when the UIDAttribute is dn", name: "when the UIDAttribute is dn",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "dn" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "dn" })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "cn=pinny,ou=users,dc=pinniped,dc=dev", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "cn=pinny,ou=users,dc=pinniped,dc=dev", Groups: []string{}},
}, },
@ -148,7 +136,7 @@ func TestLDAPSearch(t *testing.T) {
name: "when the UIDAttribute is sn", name: "when the UIDAttribute is sn",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "sn" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "sn" })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "Seal", Groups: []string{}}, User: &user.DefaultInfo{Name: "pinny", UID: "Seal", Groups: []string{}},
}, },
@ -157,7 +145,7 @@ func TestLDAPSearch(t *testing.T) {
name: "when the UsernameAttribute is sn", name: "when the UsernameAttribute is sn",
username: "seAl", // note that this is not case-sensitive! sn=Seal. The server decides which fields are compared case-sensitive. username: "seAl", // note that this is not case-sensitive! sn=Seal. The server decides which fields are compared case-sensitive.
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "sn" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "sn" })),
wantAuthResponse: &authenticator.Response{ wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "Seal", UID: "1000", Groups: []string{}}, // note that the final answer has case preserved from the entry User: &user.DefaultInfo{Name: "Seal", UID: "1000", Groups: []string{}}, // note that the final answer has case preserved from the entry
}, },
@ -166,202 +154,202 @@ func TestLDAPSearch(t *testing.T) {
name: "when the UsernameAttribute is dn and there is no user search filter provided", name: "when the UsernameAttribute is dn and there is no user search filter provided",
username: "cn=pinny,ou=users,dc=pinniped,dc=dev", username: "cn=pinny,ou=users,dc=pinniped,dc=dev",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.UsernameAttribute = "dn" p.UserSearch.UsernameAttribute = "dn"
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
}), })),
wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`,
}, },
{ {
name: "when the bind user username is not a valid DN", name: "when the bind user username is not a valid DN",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.BindUsername = "invalid-dn" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "invalid-dn" })),
wantError: `error binding as "invalid-dn" before user search: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, wantError: `error binding as "invalid-dn" before user search: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`,
}, },
{ {
name: "when the bind user username is wrong", name: "when the bind user username is wrong",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" })),
wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
}, },
{ {
name: "when the bind user password is wrong", name: "when the bind user password is wrong",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.BindPassword = "wrong-password" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })),
wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
}, },
{ {
name: "when the end user password is wrong", name: "when the end user password is wrong",
username: "pinny", username: "pinny",
password: "wrong-pinny-password", password: "wrong-pinny-password",
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when the end user password has the wrong case (passwords are compared as case-sensitive)", name: "when the end user password has the wrong case (passwords are compared as case-sensitive)",
username: "pinny", username: "pinny",
password: strings.ToUpper(pinnyPassword), password: strings.ToUpper(pinnyPassword),
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when the end user username is wrong", name: "when the end user username is wrong",
username: "wrong-username", username: "wrong-username",
password: pinnyPassword, password: pinnyPassword,
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when the user search filter does not compile", name: "when the user search filter does not compile",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Filter = "*" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "*" })),
wantError: `error searching for user "pinny": LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`, wantError: `error searching for user "pinny": LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`,
}, },
{ {
name: "when there are too many search results for the user", name: "when there are too many search results for the user",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.Filter = "objectClass=*" // overly broad search filter p.UserSearch.Filter = "objectClass=*" // overly broad search filter
}), })),
wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `, wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `,
}, },
{ {
name: "when the server is unreachable", name: "when the server is unreachable",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.Host = "127.0.0.1:" + unusedHostPort }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedHostPort })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort),
}, },
{ {
name: "when the server is not parsable", name: "when the server is not parsable",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.Host = "too:many:ports" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
}, },
{ {
name: "when the CA bundle is not parsable", name: "when the CA bundle is not parsable",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.CABundle = []byte("invalid-pem") }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort),
}, },
{ {
name: "when the CA bundle does not cause the host to be trusted", name: "when the CA bundle does not cause the host to be trusted",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.CABundle = nil }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort),
}, },
{ {
name: "when the UsernameAttribute attribute has multiple values in the entry", name: "when the UsernameAttribute attribute has multiple values in the entry",
username: "wally.ldap@example.com", username: "wally.ldap@example.com",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "mail" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "mail" })),
wantError: `found 2 values for attribute "mail" while searching for user "wally.ldap@example.com", but expected 1 result`, wantError: `found 2 values for attribute "mail" while searching for user "wally.ldap@example.com", but expected 1 result`,
}, },
{ {
name: "when the UIDAttribute attribute has multiple values in the entry", name: "when the UIDAttribute attribute has multiple values in the entry",
username: "wally", username: "wally",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "mail" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "mail" })),
wantError: `found 2 values for attribute "mail" while searching for user "wally", but expected 1 result`, wantError: `found 2 values for attribute "mail" while searching for user "wally", but expected 1 result`,
}, },
{ {
name: "when the UsernameAttribute attribute is not found in the entry", name: "when the UsernameAttribute attribute is not found in the entry",
username: "wally", username: "wally",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.Filter = "cn={}" p.UserSearch.Filter = "cn={}"
p.UserSearch.UsernameAttribute = "attr-does-not-exist" p.UserSearch.UsernameAttribute = "attr-does-not-exist"
}), })),
wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`,
}, },
{ {
name: "when the UIDAttribute attribute is not found in the entry", name: "when the UIDAttribute attribute is not found in the entry",
username: "wally", username: "wally",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "attr-does-not-exist" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "attr-does-not-exist" })),
wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`,
}, },
{ {
name: "when the UsernameAttribute has the wrong case", name: "when the UsernameAttribute has the wrong case",
username: "Seal", username: "Seal",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "SN" }), // this is case-sensitive provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "SN" })), // this is case-sensitive
wantError: `found 0 values for attribute "SN" while searching for user "Seal", but expected 1 result`, wantError: `found 0 values for attribute "SN" while searching for user "Seal", but expected 1 result`,
}, },
{ {
name: "when the UIDAttribute has the wrong case", name: "when the UIDAttribute has the wrong case",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "SN" }), // this is case-sensitive provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "SN" })), // this is case-sensitive
wantError: `found 0 values for attribute "SN" while searching for user "pinny", but expected 1 result`, wantError: `found 0 values for attribute "SN" while searching for user "pinny", but expected 1 result`,
}, },
{ {
name: "when the UsernameAttribute is DN and has the wrong case", name: "when the UsernameAttribute is DN and has the wrong case",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.UsernameAttribute = "DN" // dn must be lower-case p.UserSearch.UsernameAttribute = "DN" // dn must be lower-case
p.UserSearch.Filter = "cn={}" p.UserSearch.Filter = "cn={}"
}), })),
wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`,
}, },
{ {
name: "when the UIDAttribute is DN and has the wrong case", name: "when the UIDAttribute is DN and has the wrong case",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.UIDAttribute = "DN" // dn must be lower-case p.UserSearch.UIDAttribute = "DN" // dn must be lower-case
}), })),
wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`,
}, },
{ {
name: "when the search base is invalid", name: "when the search base is invalid",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "invalid-base" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "invalid-base" })),
wantError: `error searching for user "pinny": LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, wantError: `error searching for user "pinny": LDAP Result Code 34 "Invalid DN Syntax": invalid DN`,
}, },
{ {
name: "when the search base does not exist", name: "when the search base does not exist",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })),
wantError: `error searching for user "pinny": LDAP Result Code 32 "No Such Object": `, wantError: `error searching for user "pinny": LDAP Result Code 32 "No Such Object": `,
}, },
{ {
name: "when the search base causes no search results", name: "when the search base causes no search results",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "ou=groups,dc=pinniped,dc=dev" }), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=groups,dc=pinniped,dc=dev" })),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when there is no username specified", name: "when there is no username specified",
username: "", username: "",
password: pinnyPassword, password: pinnyPassword,
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
{ {
name: "when there is no password specified", name: "when there is no password specified",
username: "pinny", username: "pinny",
password: "", password: "",
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantError: `error binding for user "pinny" using provided password against DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 206 "Empty password not allowed by the client": ldap: empty password not allowed by the client`, wantError: `error binding for user "pinny" using provided password against DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 206 "Empty password not allowed by the client": ldap: empty password not allowed by the client`,
}, },
{ {
name: "when the user has no password in their entry", name: "when the user has no password in their entry",
username: "olive", username: "olive",
password: "anything", password: "anything",
provider: provider(nil), provider: upstreamldap.New(*providerConfig(nil)),
wantUnauthenticated: true, wantUnauthenticated: true,
}, },
} }
@ -389,6 +377,75 @@ func TestLDAPSearch(t *testing.T) {
} }
} }
func TestSimultaneousRequestsOnSingleProvider(t *testing.T) {
env := library.IntegrationEnv(t)
// Note that these tests depend on the values hard-coded in the LDIF file in test/deploy/tools/ldap.yaml.
// It requires the test LDAP server from the tools deployment.
if len(env.ToolsNamespace) == 0 {
t.Skip("Skipping test because it requires the test LDAP server in the tools namespace.")
}
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
cancelFunc() // this will send SIGKILL to the subprocess, just in case
})
ldapHostPort := findRecentlyUnusedLocalhostPorts(t, 1)[0]
// Expose the the test LDAP server's TLS port on the localhost.
startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace)
provider := upstreamldap.New(*defaultProviderConfig(env, ldapHostPort))
// Making multiple simultaneous requests on the same upstreamldap.Provider instance should all succeed
// without triggering the race detector.
iterations := 150
resultCh := make(chan authUserResult, iterations)
for i := 0; i < iterations; i++ {
go func() {
authResponse, authenticated, err := provider.AuthenticateUser(ctx,
env.SupervisorUpstreamLDAP.TestUserCN, env.SupervisorUpstreamLDAP.TestUserPassword,
)
resultCh <- authUserResult{
response: authResponse,
authenticated: authenticated,
err: err,
}
}()
}
for i := 0; i < iterations; i++ {
result := <-resultCh
require.NoError(t, result.err)
require.True(t, result.authenticated, "expected the user to be authenticated, but they were not")
require.Equal(t, &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}},
}, result.response)
}
}
type authUserResult struct {
response *authenticator.Response
authenticated bool
err error
}
func defaultProviderConfig(env *library.TestEnv, ldapHostPort string) *upstreamldap.ProviderConfig {
return &upstreamldap.ProviderConfig{
Name: "test-ldap-provider",
Host: "127.0.0.1:" + ldapHostPort,
CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle),
BindUsername: "cn=admin,dc=pinniped,dc=dev",
BindPassword: "password",
UserSearch: upstreamldap.UserSearchConfig{
Base: "ou=users,dc=pinniped,dc=dev",
Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case
UsernameAttribute: "cn",
UIDAttribute: "uidNumber",
},
}
}
func startKubectlPortForward(ctx context.Context, t *testing.T, hostPort, remotePort, serviceName, namespace string) { func startKubectlPortForward(ctx context.Context, t *testing.T, hostPort, remotePort, serviceName, namespace string) {
t.Helper() t.Helper()
startLongRunningCommandAndWaitForInitialOutput(ctx, t, startLongRunningCommandAndWaitForInitialOutput(ctx, t,