From e6e6497022b0661c04b309da03f0bc9234df4ed7 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 15 Apr 2021 10:25:35 -0700 Subject: [PATCH] Introduce upstreamldap.New to prevent changes to the underlying config Makes it easier to support using the same upstreamldap.Provider from multiple goroutines safely. --- .../upstreamwatcher/ldap_upstream_watcher.go | 14 +- .../ldap_upstream_watcher_test.go | 32 +-- internal/upstreamldap/upstreamldap.go | 68 +++--- internal/upstreamldap/upstreamldap_test.go | 199 ++++++++++-------- test/integration/ldap_client_test.go | 177 ++++++++++------ 5 files changed, 296 insertions(+), 194 deletions(-) diff --git a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go index 1466ffd3..c119ac57 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go @@ -107,10 +107,10 @@ func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error { func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI { spec := upstream.Spec - result := &upstreamldap.Provider{ + config := &upstreamldap.ProviderConfig{ Name: upstream.Name, Host: spec.Host, - UserSearch: &upstreamldap.UserSearch{ + UserSearch: upstreamldap.UserSearchConfig{ Base: spec.UserSearch.Base, Filter: spec.UserSearch.Filter, UsernameAttribute: spec.UserSearch.Attributes.Username, @@ -119,17 +119,17 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream * Dialer: c.ldapDialer, } conditions := []*v1alpha1.Condition{ - c.validateSecret(upstream, result), - c.validateTLSConfig(upstream, result), + c.validateSecret(upstream, config), + c.validateTLSConfig(upstream, config), } hadErrorCondition := c.updateStatus(ctx, upstream, conditions) if hadErrorCondition { return nil } - return result + return upstreamldap.New(*config) } -func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.Provider) *v1alpha1.Condition { +func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition { tlsSpec := upstream.Spec.TLS if tlsSpec == nil { return c.validTLSCondition(noTLSConfigurationMessage) @@ -171,7 +171,7 @@ func (c *ldapWatcherController) invalidTLSCondition(message string) *v1alpha1.Co } } -func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.Provider) *v1alpha1.Condition { +func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition { secretName := upstream.Spec.Bind.SecretName secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName) diff --git a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go index 571d3bd9..ca362db9 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go @@ -194,13 +194,13 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { return deepCopy } - providerForValidUpstream := &upstreamldap.Provider{ + providerConfigForValidUpstream := &upstreamldap.ProviderConfig{ Name: testName, Host: testHost, CABundle: testCABundle, BindUsername: testBindUsername, BindPassword: testBindPassword, - UserSearch: &upstreamldap.UserSearch{ + UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, UsernameAttribute: testUsernameAttrName, @@ -215,7 +215,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { inputSecrets []runtime.Object ldapDialer upstreamldap.LDAPDialer wantErr string - wantResultingCache []*upstreamldap.Provider + wantResultingCache []*upstreamldap.ProviderConfig wantResultingUpstreams []v1alpha1.LDAPIdentityProvider }{ { @@ -230,7 +230,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, - wantResultingCache: []*upstreamldap.Provider{providerForValidUpstream}, + wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -262,7 +262,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { inputUpstreams: []runtime.Object{validUpstream}, inputSecrets: []runtime.Object{}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{}, + wantResultingCache: []*upstreamldap.ProviderConfig{}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -298,7 +298,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Data: testValidSecretData, }}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{}, + wantResultingCache: []*upstreamldap.ProviderConfig{}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -333,7 +333,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, }}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{}, + wantResultingCache: []*upstreamldap.ProviderConfig{}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -371,7 +371,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Data: testValidSecretData, }}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{}, + wantResultingCache: []*upstreamldap.ProviderConfig{}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -409,7 +409,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Data: testValidSecretData, }}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{}, + wantResultingCache: []*upstreamldap.ProviderConfig{}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.LDAPIdentityProviderStatus{ @@ -446,14 +446,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, - wantResultingCache: []*upstreamldap.Provider{ + wantResultingCache: []*upstreamldap.ProviderConfig{ { Name: testName, Host: testHost, CABundle: nil, BindUsername: testBindUsername, BindPassword: testBindPassword, - UserSearch: &upstreamldap.UserSearch{ + UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, UsernameAttribute: testUsernameAttrName, @@ -498,14 +498,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, - wantResultingCache: []*upstreamldap.Provider{ + wantResultingCache: []*upstreamldap.ProviderConfig{ { Name: testName, Host: testHost, CABundle: nil, BindUsername: testBindUsername, BindPassword: testBindPassword, - UserSearch: &upstreamldap.UserSearch{ + UserSearch: upstreamldap.UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, UsernameAttribute: testUsernameAttrName, @@ -553,7 +553,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Data: testValidSecretData, }}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), - wantResultingCache: []*upstreamldap.Provider{providerForValidUpstream}, + wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{ { ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: "other-upstream", Generation: 42}, @@ -616,7 +616,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) cache := provider.NewDynamicUpstreamIDPProvider() cache.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ - &upstreamldap.Provider{Name: "initial-entry"}, + upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}), }) controller := NewLDAPUpstreamWatcherController( @@ -647,7 +647,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) for i := range actualIDPList { actualIDP := actualIDPList[i].(*upstreamldap.Provider) - require.Equal(t, tt.wantResultingCache[i], actualIDP) + require.Equal(t, *tt.wantResultingCache[i], actualIDP.GetConfig()) } actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{}) diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 765b5fa9..f27551c7 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -50,9 +50,10 @@ func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, err return f(ctx, hostAndPort) } -// Provider includes all of the settings for connection and searching for users and groups in +// ProviderConfig includes all of the settings for connection and searching for users and groups in // the upstream LDAP IDP. It also provides methods for testing the connection and performing logins. -type Provider struct { +// The nested structs are not pointer fields to enable deep copy on function params and return values. +type ProviderConfig struct { // Name is the unique name of this upstream LDAP IDP. Name string @@ -70,14 +71,14 @@ type Provider struct { BindPassword string // UserSearch contains information about how to search for users in the upstream LDAP IDP. - UserSearch *UserSearch + UserSearch UserSearchConfig // Dialer exists to enable testing. When nil, will use a default appropriate for production use. Dialer LDAPDialer } -// UserSearch contains information about how to search for users in the upstream LDAP IDP. -type UserSearch struct { +// UserSearchConfig contains information about how to search for users in the upstream LDAP IDP. +type UserSearchConfig struct { // Base is the base DN to use for the user search in the upstream LDAP IDP. Base string @@ -93,13 +94,28 @@ type UserSearch struct { UIDAttribute string } +type Provider struct { + c ProviderConfig +} + +// Create a Provider. The config is not a pointer to ensure that a copy of the config is created, +// making the resulting Provider use an effectively read-only configuration. +func New(config ProviderConfig) *Provider { + return &Provider{c: config} +} + +// A reader for the config. Returns a copy of the config to keep the underlying config read-only. +func (p *Provider) GetConfig() ProviderConfig { + return p.c +} + func (p *Provider) dial(ctx context.Context) (Conn, error) { - hostAndPort, err := hostAndPortWithDefaultPort(p.Host, ldap.DefaultLdapsPort) + hostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - if p.Dialer != nil { - return p.Dialer.Dial(ctx, hostAndPort) + if p.c.Dialer != nil { + return p.c.Dialer.Dial(ctx, hostAndPort) } return p.dialTLS(ctx, hostAndPort) } @@ -109,8 +125,8 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) { // so we implement it ourselves, heavily inspired by ldap.DialURL. func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { rootCAs := x509.NewCertPool() - if p.CABundle != nil { - if !rootCAs.AppendCertsFromPEM(p.CABundle) { + if p.c.CABundle != nil { + if !rootCAs.AppendCertsFromPEM(p.c.CABundle) { return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle")) } } @@ -154,14 +170,14 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, // A name for this upstream provider. func (p *Provider) GetName() string { - return p.Name + return p.c.Name } // Return a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234". // This URL is not used for connecting to the provider, but rather is used for creating a globally unique user // identifier by being combined with the user's UID, since user UIDs are only unique within one provider. func (p *Provider) GetURL() string { - return fmt.Sprintf("%s://%s", ldapsScheme, p.Host) + return fmt.Sprintf("%s://%s", ldapsScheme, p.c.Host) } // TestConnection provides a method for testing the connection and bind settings. It performs a dial and bind @@ -183,7 +199,7 @@ func (p *Provider) TestAuthenticateUser(ctx context.Context, testUsername string // Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator. func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) { - if p.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.UserSearch.Filter) == 0 { + if p.c.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.c.UserSearch.Filter) == 0 { // LDAP search filters do not allow searching by DN. return nil, false, fmt.Errorf(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`) } @@ -195,13 +211,13 @@ func (p *Provider) AuthenticateUser(ctx context.Context, username, password stri conn, err := p.dial(ctx) if err != nil { - return nil, false, fmt.Errorf(`error dialing host "%s": %w`, p.Host, err) + return nil, false, fmt.Errorf(`error dialing host "%s": %w`, p.c.Host, err) } defer conn.Close() - err = conn.Bind(p.BindUsername, p.BindPassword) + err = conn.Bind(p.c.BindUsername, p.c.BindPassword) if err != nil { - return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.BindUsername, err) + return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.c.BindUsername, err) } mappedUsername, mappedUID, err := p.searchAndBindUser(conn, username, password) @@ -243,12 +259,12 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, password string return "", "", fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username) } - mappedUsername, err := p.getSearchResultAttributeValue(p.UserSearch.UsernameAttribute, userEntry, username) + mappedUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, username) if err != nil { return "", "", err } - mappedUID, err := p.getSearchResultAttributeValue(p.UserSearch.UIDAttribute, userEntry, username) + mappedUID, err := p.getSearchResultAttributeValue(p.c.UserSearch.UIDAttribute, userEntry, username) if err != nil { return "", "", err } @@ -270,7 +286,7 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, password string func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest { // See https://ldap.com/the-ldap-search-operation for general documentation of LDAP search options. return &ldap.SearchRequest{ - BaseDN: p.UserSearch.Base, + BaseDN: p.c.UserSearch.Base, Scope: ldap.ScopeWholeSubtree, DerefAliases: ldap.DerefAlways, // TODO what's the best value here? SizeLimit: 2, @@ -284,21 +300,21 @@ func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest { func (p *Provider) userSearchRequestedAttributes() []string { attributes := []string{} - if p.UserSearch.UsernameAttribute != distinguishedNameAttributeName { - attributes = append(attributes, p.UserSearch.UsernameAttribute) + if p.c.UserSearch.UsernameAttribute != distinguishedNameAttributeName { + attributes = append(attributes, p.c.UserSearch.UsernameAttribute) } - if p.UserSearch.UIDAttribute != distinguishedNameAttributeName { - attributes = append(attributes, p.UserSearch.UIDAttribute) + if p.c.UserSearch.UIDAttribute != distinguishedNameAttributeName { + attributes = append(attributes, p.c.UserSearch.UIDAttribute) } return attributes } func (p *Provider) userSearchFilter(username string) string { safeUsername := p.escapeUsernameForSearchFilter(username) - if len(p.UserSearch.Filter) == 0 { - return fmt.Sprintf("(%s=%s)", p.UserSearch.UsernameAttribute, safeUsername) + if len(p.c.UserSearch.Filter) == 0 { + return fmt.Sprintf("(%s=%s)", p.c.UserSearch.UsernameAttribute, safeUsername) } - filter := strings.ReplaceAll(p.UserSearch.Filter, userSearchFilterInterpolationLocationMarker, safeUsername) + filter := strings.ReplaceAll(p.c.UserSearch.Filter, userSearchFilterInterpolationLocationMarker, safeUsername) if strings.HasPrefix(filter, "(") && strings.HasSuffix(filter, ")") { return filter } diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 2615f084..b6b21846 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -43,14 +43,14 @@ var ( ) func TestAuthenticateUser(t *testing.T) { - provider := func(editFunc func(p *Provider)) *Provider { - provider := &Provider{ + providerConfig := func(editFunc func(p *ProviderConfig)) *ProviderConfig { + config := &ProviderConfig{ Name: "some-provider-name", Host: testHost, CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test BindUsername: testBindUsername, BindPassword: testBindPassword, - UserSearch: &UserSearch{ + UserSearch: UserSearchConfig{ Base: testUserSearchBase, Filter: testUserSearchFilter, UsernameAttribute: testUserSearchUsernameAttribute, @@ -58,9 +58,9 @@ func TestAuthenticateUser(t *testing.T) { }, } if editFunc != nil { - editFunc(provider) + editFunc(config) } - return provider + return config } expectedSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest { @@ -85,7 +85,7 @@ func TestAuthenticateUser(t *testing.T) { name string username string password string - provider *Provider + providerConfig *ProviderConfig setupMocks func(conn *mockldapconn.MockConn) dialError error wantError string @@ -94,10 +94,10 @@ func TestAuthenticateUser(t *testing.T) { wantUnauthenticated bool }{ { - name: "happy path", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "happy path", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -128,7 +128,7 @@ func TestAuthenticateUser(t *testing.T) { name: "when the user search filter is already wrapped by parenthesis then it is not wrapped again", username: testUpstreamUsername, password: testUpstreamPassword, - provider: provider(func(p *Provider) { + providerConfig: providerConfig(func(p *ProviderConfig) { p.UserSearch.Filter = "(" + testUserSearchFilter + ")" }), setupMocks: func(conn *mockldapconn.MockConn) { @@ -159,7 +159,7 @@ func TestAuthenticateUser(t *testing.T) { name: "when the UsernameAttribute is dn and there is a user search filter provided", username: testUpstreamUsername, password: testUpstreamPassword, - provider: provider(func(p *Provider) { + providerConfig: providerConfig(func(p *ProviderConfig) { p.UserSearch.UsernameAttribute = "dn" }), setupMocks: func(conn *mockldapconn.MockConn) { @@ -191,7 +191,7 @@ func TestAuthenticateUser(t *testing.T) { name: "when the UIDAttribute is dn", username: testUpstreamUsername, password: testUpstreamPassword, - provider: provider(func(p *Provider) { + providerConfig: providerConfig(func(p *ProviderConfig) { p.UserSearch.UIDAttribute = "dn" }), setupMocks: func(conn *mockldapconn.MockConn) { @@ -223,7 +223,7 @@ func TestAuthenticateUser(t *testing.T) { name: "when Filter is blank it derives a search filter from the UsernameAttribute", username: testUpstreamUsername, password: testUpstreamPassword, - provider: provider(func(p *Provider) { + providerConfig: providerConfig(func(p *ProviderConfig) { p.UserSearch.Filter = "" }), setupMocks: func(conn *mockldapconn.MockConn) { @@ -253,10 +253,10 @@ func TestAuthenticateUser(t *testing.T) { }, }, { - name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter", - username: `a&b|c(d)e\f*g`, - password: testUpstreamPassword, - provider: provider(nil), + name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter", + username: `a&b|c(d)e\f*g`, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { @@ -284,18 +284,18 @@ func TestAuthenticateUser(t *testing.T) { }, }, { - name: "when dial fails", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), - dialError: errors.New("some dial error"), - wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), + name: "when dial fails", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), + dialError: errors.New("some dial error"), + wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), }, { name: "when the UsernameAttribute is dn and there is not a user search filter provided", username: testUpstreamUsername, password: testUpstreamPassword, - provider: provider(func(p *Provider) { + providerConfig: providerConfig(func(p *ProviderConfig) { p.UserSearch.UsernameAttribute = "dn" p.UserSearch.Filter = "" }), @@ -303,10 +303,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, }, { - name: "when binding as the bind user returns an error", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when binding as the bind user returns an error", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Close().Times(1) @@ -314,10 +314,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`error binding as "%s" before user search: some bind error`, testBindUsername), }, { - name: "when searching for the user returns an error", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns an error", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(nil, errors.New("some search error")).Times(1) @@ -326,10 +326,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`error searching for user "%s": some search error`, testUpstreamUsername), }, { - name: "when searching for the user returns no results", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns no results", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -340,10 +340,10 @@ func TestAuthenticateUser(t *testing.T) { wantUnauthenticated: true, }, { - name: "when searching for the user returns multiple results", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns multiple results", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -357,10 +357,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`searching for user "%s" resulted in 2 search results, but expected 1 result`, testUpstreamUsername), }, { - name: "when searching for the user returns a user without a DN", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user without a DN", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -373,10 +373,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`searching for user "%s" resulted in search result without DN`, testUpstreamUsername), }, { - name: "when searching for the user returns a user without an expected username attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user without an expected username attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -394,10 +394,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername), }, { - name: "when searching for the user returns a user with too many values for the expected username attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user with too many values for the expected username attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -419,10 +419,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUsernameAttribute, testUpstreamUsername), }, { - name: "when searching for the user returns a user with an empty value for the expected username attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user with an empty value for the expected username attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -441,10 +441,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUsernameAttribute, testUpstreamUsername), }, { - name: "when searching for the user returns a user without an expected UID attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user without an expected UID attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -462,10 +462,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), }, { - name: "when searching for the user returns a user with too many values for the expected UID attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user with too many values for the expected UID attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -487,10 +487,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), }, { - name: "when searching for the user returns a user with an empty value for the expected UID attribute", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when searching for the user returns a user with an empty value for the expected UID attribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -509,10 +509,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername), }, { - name: "when binding as the found user returns an error", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when binding as the found user returns an error", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -532,10 +532,10 @@ func TestAuthenticateUser(t *testing.T) { wantError: fmt.Sprintf(`error binding for user "%s" using provided password against DN "%s": some bind error`, testUpstreamUsername, testSearchResultDNValue), }, { - name: "when binding as the found user returns a specific invalid credentials error", - username: testUpstreamUsername, - password: testUpstreamPassword, - provider: provider(nil), + name: "when binding as the found user returns a specific invalid credentials error", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ @@ -558,7 +558,7 @@ func TestAuthenticateUser(t *testing.T) { name: "when no username is specified", username: "", password: testUpstreamPassword, - provider: provider(nil), + providerConfig: providerConfig(nil), wantToSkipDial: true, wantUnauthenticated: true, }, @@ -576,16 +576,17 @@ func TestAuthenticateUser(t *testing.T) { } dialWasAttempted := false - tt.provider.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { dialWasAttempted = true - require.Equal(t, tt.provider.Host, hostAndPort) + require.Equal(t, tt.providerConfig.Host, hostAndPort) if tt.dialError != nil { return nil, tt.dialError } return conn, nil }) - authResponse, authenticated, err := tt.provider.AuthenticateUser(context.Background(), tt.username, tt.password) + provider := New(*tt.providerConfig) + authResponse, authenticated, err := provider.AuthenticateUser(context.Background(), tt.username, tt.password) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) @@ -607,9 +608,37 @@ func TestAuthenticateUser(t *testing.T) { } } +func TestGetConfig(t *testing.T) { + c := ProviderConfig{ + Name: "original-provider-name", + Host: testHost, + CABundle: []byte("some-ca-bundle"), + BindUsername: testBindUsername, + BindPassword: testBindPassword, + UserSearch: UserSearchConfig{ + Base: testUserSearchBase, + Filter: testUserSearchFilter, + UsernameAttribute: testUserSearchUsernameAttribute, + UIDAttribute: testUserSearchUIDAttribute, + }, + } + p := New(c) + require.Equal(t, c, p.c) + require.Equal(t, c, p.GetConfig()) + + // The original config can be changed without impacting the provider, since the provider made a copy of the config. + c.Name = "changed-name" + require.Equal(t, "original-provider-name", p.c.Name) + + // The return value of GetConfig can be modified without impacting the provider, since it is a copy of the config. + returnedConfig := p.GetConfig() + returnedConfig.Name = "changed-name" + require.Equal(t, "original-provider-name", p.c.Name) +} + func TestGetURL(t *testing.T) { - require.Equal(t, "ldaps://ldap.example.com:1234", (&Provider{Host: "ldap.example.com:1234"}).GetURL()) - require.Equal(t, "ldaps://ldap.example.com", (&Provider{Host: "ldap.example.com"}).GetURL()) + require.Equal(t, "ldaps://ldap.example.com:1234", New(ProviderConfig{Host: "ldap.example.com:1234"}).GetURL()) + require.Equal(t, "ldaps://ldap.example.com", New(ProviderConfig{Host: "ldap.example.com"}).GetURL()) } // Testing of host parsing, TLS negotiation, and CA bundle, etc. for the production code's dialer. @@ -673,11 +702,11 @@ func TestRealTLSDialing(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - provider := &Provider{ + provider := New(ProviderConfig{ Host: test.host, CABundle: test.caBundle, Dialer: nil, // this test is for the default (production) dialer - } + }) conn, err := provider.dial(test.context) if conn != nil { defer conn.Close() diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index 3a201834..fb49a099 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -43,24 +43,12 @@ func TestLDAPSearch(t *testing.T) { // Expose the the test LDAP server's TLS port on the localhost. startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace) - provider := func(editFunc func(p *upstreamldap.Provider)) *upstreamldap.Provider { - provider := &upstreamldap.Provider{ - Name: "test-ldap-provider", - Host: "127.0.0.1:" + ldapHostPort, - CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle), - BindUsername: "cn=admin,dc=pinniped,dc=dev", - BindPassword: "password", - UserSearch: &upstreamldap.UserSearch{ - Base: "ou=users,dc=pinniped,dc=dev", - Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case - UsernameAttribute: "cn", - UIDAttribute: "uidNumber", - }, - } + providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig { + providerConfig := defaultProviderConfig(env, ldapHostPort) if editFunc != nil { - editFunc(provider) + editFunc(providerConfig) } - return provider + return providerConfig } pinnyPassword := env.SupervisorUpstreamLDAP.TestUserPassword @@ -78,7 +66,7 @@ func TestLDAPSearch(t *testing.T) { name: "happy path", username: "pinny", password: pinnyPassword, - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, }, @@ -87,7 +75,7 @@ func TestLDAPSearch(t *testing.T) { name: "using a different user search base", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "dc=pinniped,dc=dev" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "dc=pinniped,dc=dev" })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, }, @@ -96,7 +84,7 @@ func TestLDAPSearch(t *testing.T) { name: "when the user search filter is already wrapped by parenthesis", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Filter = "(cn={})" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "(cn={})" })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, }, @@ -105,10 +93,10 @@ func TestLDAPSearch(t *testing.T) { name: "when the UsernameAttribute is dn and a user search filter is provided", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "dn" p.UserSearch.Filter = "cn={}" - }), + })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "cn=pinny,ou=users,dc=pinniped,dc=dev", UID: "1000", Groups: []string{}}, }, @@ -117,9 +105,9 @@ func TestLDAPSearch(t *testing.T) { name: "when the user search filter allows for different ways of logging in and the first one is used", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "(|(cn={})(mail={}))" - }), + })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, }, @@ -128,9 +116,9 @@ func TestLDAPSearch(t *testing.T) { name: "when the user search filter allows for different ways of logging in and the second one is used", username: "pinny.ldap@example.com", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "(|(cn={})(mail={}))" - }), + })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, }, @@ -139,7 +127,7 @@ func TestLDAPSearch(t *testing.T) { name: "when the UIDAttribute is dn", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "dn" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "dn" })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "cn=pinny,ou=users,dc=pinniped,dc=dev", Groups: []string{}}, }, @@ -148,7 +136,7 @@ func TestLDAPSearch(t *testing.T) { name: "when the UIDAttribute is sn", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "sn" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "sn" })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "pinny", UID: "Seal", Groups: []string{}}, }, @@ -157,7 +145,7 @@ func TestLDAPSearch(t *testing.T) { name: "when the UsernameAttribute is sn", username: "seAl", // note that this is not case-sensitive! sn=Seal. The server decides which fields are compared case-sensitive. password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "sn" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "sn" })), wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{Name: "Seal", UID: "1000", Groups: []string{}}, // note that the final answer has case preserved from the entry }, @@ -166,202 +154,202 @@ func TestLDAPSearch(t *testing.T) { name: "when the UsernameAttribute is dn and there is no user search filter provided", username: "cn=pinny,ou=users,dc=pinniped,dc=dev", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "dn" p.UserSearch.Filter = "" - }), + })), wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, }, { name: "when the bind user username is not a valid DN", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.BindUsername = "invalid-dn" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "invalid-dn" })), wantError: `error binding as "invalid-dn" before user search: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, }, { name: "when the bind user username is wrong", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" })), wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, }, { name: "when the bind user password is wrong", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.BindPassword = "wrong-password" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })), wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, }, { name: "when the end user password is wrong", username: "pinny", password: "wrong-pinny-password", - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantUnauthenticated: true, }, { name: "when the end user password has the wrong case (passwords are compared as case-sensitive)", username: "pinny", password: strings.ToUpper(pinnyPassword), - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantUnauthenticated: true, }, { name: "when the end user username is wrong", username: "wrong-username", password: pinnyPassword, - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantUnauthenticated: true, }, { name: "when the user search filter does not compile", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Filter = "*" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "*" })), wantError: `error searching for user "pinny": LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`, }, { name: "when there are too many search results for the user", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "objectClass=*" // overly broad search filter - }), + })), wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `, }, { name: "when the server is unreachable", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.Host = "127.0.0.1:" + unusedHostPort }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedHostPort })), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort), }, { name: "when the server is not parsable", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.Host = "too:many:ports" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, }, { name: "when the CA bundle is not parsable", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.CABundle = []byte("invalid-pem") }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort), }, { name: "when the CA bundle does not cause the host to be trusted", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.CABundle = nil }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort), }, { name: "when the UsernameAttribute attribute has multiple values in the entry", username: "wally.ldap@example.com", password: "unused-because-error-is-before-bind", - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "mail" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "mail" })), wantError: `found 2 values for attribute "mail" while searching for user "wally.ldap@example.com", but expected 1 result`, }, { name: "when the UIDAttribute attribute has multiple values in the entry", username: "wally", password: "unused-because-error-is-before-bind", - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "mail" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "mail" })), wantError: `found 2 values for attribute "mail" while searching for user "wally", but expected 1 result`, }, { name: "when the UsernameAttribute attribute is not found in the entry", username: "wally", password: "unused-because-error-is-before-bind", - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "cn={}" p.UserSearch.UsernameAttribute = "attr-does-not-exist" - }), + })), wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, }, { name: "when the UIDAttribute attribute is not found in the entry", username: "wally", password: "unused-because-error-is-before-bind", - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "attr-does-not-exist" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "attr-does-not-exist" })), wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, }, { name: "when the UsernameAttribute has the wrong case", username: "Seal", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UsernameAttribute = "SN" }), // this is case-sensitive + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "SN" })), // this is case-sensitive wantError: `found 0 values for attribute "SN" while searching for user "Seal", but expected 1 result`, }, { name: "when the UIDAttribute has the wrong case", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.UIDAttribute = "SN" }), // this is case-sensitive + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "SN" })), // this is case-sensitive wantError: `found 0 values for attribute "SN" while searching for user "pinny", but expected 1 result`, }, { name: "when the UsernameAttribute is DN and has the wrong case", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "DN" // dn must be lower-case p.UserSearch.Filter = "cn={}" - }), + })), wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, }, { name: "when the UIDAttribute is DN and has the wrong case", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "DN" // dn must be lower-case - }), + })), wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, }, { name: "when the search base is invalid", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "invalid-base" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "invalid-base" })), wantError: `error searching for user "pinny": LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, }, { name: "when the search base does not exist", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })), wantError: `error searching for user "pinny": LDAP Result Code 32 "No Such Object": `, }, { name: "when the search base causes no search results", username: "pinny", password: pinnyPassword, - provider: provider(func(p *upstreamldap.Provider) { p.UserSearch.Base = "ou=groups,dc=pinniped,dc=dev" }), + provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=groups,dc=pinniped,dc=dev" })), wantUnauthenticated: true, }, { name: "when there is no username specified", username: "", password: pinnyPassword, - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantUnauthenticated: true, }, { name: "when there is no password specified", username: "pinny", password: "", - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantError: `error binding for user "pinny" using provided password against DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 206 "Empty password not allowed by the client": ldap: empty password not allowed by the client`, }, { name: "when the user has no password in their entry", username: "olive", password: "anything", - provider: provider(nil), + provider: upstreamldap.New(*providerConfig(nil)), wantUnauthenticated: true, }, } @@ -389,6 +377,75 @@ func TestLDAPSearch(t *testing.T) { } } +func TestSimultaneousRequestsOnSingleProvider(t *testing.T) { + env := library.IntegrationEnv(t) + + // Note that these tests depend on the values hard-coded in the LDIF file in test/deploy/tools/ldap.yaml. + // It requires the test LDAP server from the tools deployment. + if len(env.ToolsNamespace) == 0 { + t.Skip("Skipping test because it requires the test LDAP server in the tools namespace.") + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(func() { + cancelFunc() // this will send SIGKILL to the subprocess, just in case + }) + + ldapHostPort := findRecentlyUnusedLocalhostPorts(t, 1)[0] + + // Expose the the test LDAP server's TLS port on the localhost. + startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace) + + provider := upstreamldap.New(*defaultProviderConfig(env, ldapHostPort)) + + // Making multiple simultaneous requests on the same upstreamldap.Provider instance should all succeed + // without triggering the race detector. + iterations := 150 + resultCh := make(chan authUserResult, iterations) + for i := 0; i < iterations; i++ { + go func() { + authResponse, authenticated, err := provider.AuthenticateUser(ctx, + env.SupervisorUpstreamLDAP.TestUserCN, env.SupervisorUpstreamLDAP.TestUserPassword, + ) + resultCh <- authUserResult{ + response: authResponse, + authenticated: authenticated, + err: err, + } + }() + } + for i := 0; i < iterations; i++ { + result := <-resultCh + require.NoError(t, result.err) + require.True(t, result.authenticated, "expected the user to be authenticated, but they were not") + require.Equal(t, &authenticator.Response{ + User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{}}, + }, result.response) + } +} + +type authUserResult struct { + response *authenticator.Response + authenticated bool + err error +} + +func defaultProviderConfig(env *library.TestEnv, ldapHostPort string) *upstreamldap.ProviderConfig { + return &upstreamldap.ProviderConfig{ + Name: "test-ldap-provider", + Host: "127.0.0.1:" + ldapHostPort, + CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle), + BindUsername: "cn=admin,dc=pinniped,dc=dev", + BindPassword: "password", + UserSearch: upstreamldap.UserSearchConfig{ + Base: "ou=users,dc=pinniped,dc=dev", + Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case + UsernameAttribute: "cn", + UIDAttribute: "uidNumber", + }, + } +} + func startKubectlPortForward(ctx context.Context, t *testing.T, hostPort, remotePort, serviceName, namespace string) { t.Helper() startLongRunningCommandAndWaitForInitialOutput(ctx, t,