From b9ce84fd68c980fad10e96b016dfb6b37df01c15 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 15 Apr 2021 14:44:43 -0700 Subject: [PATCH] Test the LDAP config by connecting to the server in the controller --- .../upstreamwatcher/ldap_upstream_watcher.go | 59 ++++-- .../ldap_upstream_watcher_test.go | 171 ++++++++++++++---- internal/upstreamldap/upstreamldap.go | 36 +++- internal/upstreamldap/upstreamldap_test.go | 95 ++++++++++ 4 files changed, 307 insertions(+), 54 deletions(-) diff --git a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go index c119ac57..353267c9 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "encoding/base64" "fmt" + "time" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/equality" @@ -31,7 +32,9 @@ const ( // Constants related to conditions. typeBindSecretValid = "BindSecretValid" - tlsConfigurationValid = "TLSConfigurationValid" + typeTLSConfigurationValid = "TLSConfigurationValid" + typeLDAPConnectionValid = "LDAPConnectionValid" + reasonLDAPConnectionError = "LDAPConnectionError" noTLSConfigurationMessage = "no TLS configuration provided" loadedTLSConfigurationMessage = "loaded TLS configuration" ) @@ -118,18 +121,26 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream * }, Dialer: c.ldapDialer, } - conditions := []*v1alpha1.Condition{ - c.validateSecret(upstream, config), - c.validateTLSConfig(upstream, config), + + conditions := []*v1alpha1.Condition{} + secretValidCondition := c.validateSecret(upstream, config) + tlsValidCondition := c.validateTLSConfig(upstream, config) + conditions = append(conditions, secretValidCondition, tlsValidCondition) + + // No point in trying to connect to the server if the config was already determined to be invalid. + if secretValidCondition.Status == v1alpha1.ConditionTrue && tlsValidCondition.Status == v1alpha1.ConditionTrue { + conditions = append(conditions, c.validateFinishedConfig(ctx, config)) } + hadErrorCondition := c.updateStatus(ctx, upstream, conditions) if hadErrorCondition { return nil } + return upstreamldap.New(*config) } -func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition { +func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentityProvider, config *upstreamldap.ProviderConfig) *v1alpha1.Condition { tlsSpec := upstream.Spec.TLS if tlsSpec == nil { return c.validTLSCondition(noTLSConfigurationMessage) @@ -149,13 +160,37 @@ func (c *ldapWatcherController) validateTLSConfig(upstream *v1alpha1.LDAPIdentit return c.invalidTLSCondition(fmt.Sprintf("certificateAuthorityData is invalid: %s", errNoCertificates)) } - result.CABundle = bundle + config.CABundle = bundle return c.validTLSCondition(loadedTLSConfigurationMessage) } +func (c *ldapWatcherController) validateFinishedConfig(ctx context.Context, config *upstreamldap.ProviderConfig) *v1alpha1.Condition { + ldapProvider := upstreamldap.New(*config) + + testConnectionTimeout, cancelFunc := context.WithTimeout(ctx, 60*time.Second) + defer cancelFunc() + + err := ldapProvider.TestConnection(testConnectionTimeout) + if err != nil { + return &v1alpha1.Condition{ + Type: typeLDAPConnectionValid, + Status: v1alpha1.ConditionFalse, + Reason: reasonLDAPConnectionError, + Message: fmt.Sprintf(`could not successfully connect to "%s" and bind as user "%s: %s`, config.Host, config.BindUsername, err.Error()), + } + } + + return &v1alpha1.Condition{ + Type: typeLDAPConnectionValid, + Status: v1alpha1.ConditionTrue, + Reason: reasonSuccess, + Message: fmt.Sprintf(`successfully able to connect to "%s" and bind as user "%s"`, config.Host, config.BindUsername), + } +} + func (c *ldapWatcherController) validTLSCondition(message string) *v1alpha1.Condition { return &v1alpha1.Condition{ - Type: tlsConfigurationValid, + Type: typeTLSConfigurationValid, Status: v1alpha1.ConditionTrue, Reason: reasonSuccess, Message: message, @@ -164,14 +199,14 @@ func (c *ldapWatcherController) validTLSCondition(message string) *v1alpha1.Cond func (c *ldapWatcherController) invalidTLSCondition(message string) *v1alpha1.Condition { return &v1alpha1.Condition{ - Type: tlsConfigurationValid, + Type: typeTLSConfigurationValid, Status: v1alpha1.ConditionFalse, Reason: reasonInvalidTLSConfig, Message: message, } } -func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.ProviderConfig) *v1alpha1.Condition { +func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, config *upstreamldap.ProviderConfig) *v1alpha1.Condition { secretName := upstream.Spec.Bind.SecretName secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName) @@ -193,9 +228,9 @@ func (c *ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityPr } } - result.BindUsername = string(secret.Data[corev1.BasicAuthUsernameKey]) - result.BindPassword = string(secret.Data[corev1.BasicAuthPasswordKey]) - if len(result.BindUsername) == 0 || len(result.BindPassword) == 0 { + config.BindUsername = string(secret.Data[corev1.BasicAuthUsernameKey]) + config.BindPassword = string(secret.Data[corev1.BasicAuthPasswordKey]) + if len(config.BindUsername) == 0 || len(config.BindPassword) == 0 { return &v1alpha1.Condition{ Type: typeBindSecretValid, Status: v1alpha1.ConditionFalse, diff --git a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go index ca362db9..0a2bdbe0 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/ldap_upstream_watcher_test.go @@ -6,11 +6,13 @@ package upstreamwatcher import ( "context" "encoding/base64" + "errors" "fmt" "sort" "testing" "time" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -23,6 +25,7 @@ import ( pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions" "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/upstreamldap" @@ -165,13 +168,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { testCABundle := testCA.Bundle() testCABundleBase64Encoded := base64.StdEncoding.EncodeToString(testCABundle) - successfulDialer := &comparableDialer{ - f: func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) { - // TODO return a fake implementation of upstreamldap.Conn, or return an error for testing errors - return nil, nil - }, - } - validUpstream := &v1alpha1.LDAPIdentityProvider{ ObjectMeta: metav1.ObjectMeta{Name: testName, Namespace: testNamespace, Generation: 1234}, Spec: v1alpha1.LDAPIdentityProviderSpec{ @@ -206,30 +202,35 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { UsernameAttribute: testUsernameAttrName, UIDAttribute: testUIDAttrName, }, - Dialer: successfulDialer, // the dialer passed to the controller's constructor should have been passed through } tests := []struct { name string inputUpstreams []runtime.Object inputSecrets []runtime.Object - ldapDialer upstreamldap.LDAPDialer + setupMocks func(conn *mockldapconn.MockConn) + dialError error wantErr string wantResultingCache []*upstreamldap.ProviderConfig wantResultingUpstreams []v1alpha1.LDAPIdentityProvider }{ { - name: "no LDAPIdentityProvider upstreams clears the cache", + name: "no LDAPIdentityProvider upstreams clears the cache", + wantResultingCache: []*upstreamldap.ProviderConfig{}, }, { name: "one valid upstream updates the cache to include only that upstream", - ldapDialer: successfulDialer, inputUpstreams: []runtime.Object{validUpstream}, inputSecrets: []runtime.Object{&corev1.Secret{ ObjectMeta: metav1.ObjectMeta{Name: testSecretName, Namespace: testNamespace}, Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, + setupMocks: func(conn *mockldapconn.MockConn) { + // Should perform a test dial and bind. + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, @@ -244,6 +245,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Message: "loaded bind secret", ObservedGeneration: 1234, }, + { + Type: "LDAPConnectionValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: fmt.Sprintf(`successfully able to connect to "%s" and bind as user "%s"`, testHost, testBindUsername), + ObservedGeneration: 1234, + }, { Type: "TLSConfigurationValid", Status: "True", @@ -258,7 +267,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, { name: "missing secret", - ldapDialer: successfulDialer, inputUpstreams: []runtime.Object{validUpstream}, inputSecrets: []runtime.Object{}, wantErr: controllerlib.ErrSyntheticRequeue.Error(), @@ -290,7 +298,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, { name: "secret has wrong type", - ldapDialer: successfulDialer, inputUpstreams: []runtime.Object{validUpstream}, inputSecrets: []runtime.Object{&corev1.Secret{ ObjectMeta: metav1.ObjectMeta{Name: testSecretName, Namespace: testNamespace}, @@ -326,7 +333,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, { name: "secret is missing key", - ldapDialer: successfulDialer, inputUpstreams: []runtime.Object{validUpstream}, inputSecrets: []runtime.Object{&corev1.Secret{ ObjectMeta: metav1.ObjectMeta{Name: testSecretName, Namespace: testNamespace}, @@ -360,8 +366,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }}, }, { - name: "CertificateAuthorityData is not base64 encoded", - ldapDialer: successfulDialer, + name: "CertificateAuthorityData is not base64 encoded", inputUpstreams: []runtime.Object{modifiedCopyOfValidUpstream(func(upstream *v1alpha1.LDAPIdentityProvider) { upstream.Spec.TLS.CertificateAuthorityData = "this-is-not-base64-encoded" })}, @@ -398,8 +403,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }}, }, { - name: "CertificateAuthorityData is not valid pem data", - ldapDialer: successfulDialer, + name: "CertificateAuthorityData is not valid pem data", inputUpstreams: []runtime.Object{modifiedCopyOfValidUpstream(func(upstream *v1alpha1.LDAPIdentityProvider) { upstream.Spec.TLS.CertificateAuthorityData = base64.StdEncoding.EncodeToString([]byte("this is not pem data")) })}, @@ -436,8 +440,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }}, }, { - name: "nil TLS configuration", - ldapDialer: successfulDialer, + name: "nil TLS configuration is valid", inputUpstreams: []runtime.Object{modifiedCopyOfValidUpstream(func(upstream *v1alpha1.LDAPIdentityProvider) { upstream.Spec.TLS = nil })}, @@ -446,6 +449,11 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, + setupMocks: func(conn *mockldapconn.MockConn) { + // Should perform a test dial and bind. + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, wantResultingCache: []*upstreamldap.ProviderConfig{ { Name: testName, @@ -459,7 +467,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { UsernameAttribute: testUsernameAttrName, UIDAttribute: testUIDAttrName, }, - Dialer: successfulDialer, // the dialer passed to the controller's constructor should have been passed through }, }, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ @@ -475,6 +482,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Message: "loaded bind secret", ObservedGeneration: 1234, }, + { + Type: "LDAPConnectionValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: fmt.Sprintf(`successfully able to connect to "%s" and bind as user "%s"`, testHost, testBindUsername), + ObservedGeneration: 1234, + }, { Type: "TLSConfigurationValid", Status: "True", @@ -488,8 +503,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }}, }, { - name: "non-nil TLS configuration with empty CertificateAuthorityData", - ldapDialer: successfulDialer, + name: "non-nil TLS configuration with empty CertificateAuthorityData is valid", inputUpstreams: []runtime.Object{modifiedCopyOfValidUpstream(func(upstream *v1alpha1.LDAPIdentityProvider) { upstream.Spec.TLS.CertificateAuthorityData = "" })}, @@ -498,6 +512,11 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, + setupMocks: func(conn *mockldapconn.MockConn) { + // Should perform a test dial and bind. + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, wantResultingCache: []*upstreamldap.ProviderConfig{ { Name: testName, @@ -511,7 +530,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { UsernameAttribute: testUsernameAttrName, UIDAttribute: testUIDAttrName, }, - Dialer: successfulDialer, // the dialer passed to the controller's constructor should have been passed through }, }, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ @@ -527,6 +545,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Message: "loaded bind secret", ObservedGeneration: 1234, }, + { + Type: "LDAPConnectionValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: fmt.Sprintf(`successfully able to connect to "%s" and bind as user "%s"`, testHost, testBindUsername), + ObservedGeneration: 1234, + }, { Type: "TLSConfigurationValid", Status: "True", @@ -540,8 +566,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }}, }, { - name: "one valid upstream and one invalid upstream updates the cache to include only the valid upstream", - ldapDialer: successfulDialer, + name: "one valid upstream and one invalid upstream updates the cache to include only the valid upstream", inputUpstreams: []runtime.Object{validUpstream, modifiedCopyOfValidUpstream(func(upstream *v1alpha1.LDAPIdentityProvider) { upstream.Name = "other-upstream" upstream.Generation = 42 @@ -552,6 +577,11 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Type: corev1.SecretTypeBasicAuth, Data: testValidSecretData, }}, + setupMocks: func(conn *mockldapconn.MockConn) { + // Should perform a test dial and bind for the one valid upstream configuration. + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, wantErr: controllerlib.ErrSyntheticRequeue.Error(), wantResultingCache: []*upstreamldap.ProviderConfig{providerConfigForValidUpstream}, wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{ @@ -592,6 +622,14 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { Message: "loaded bind secret", ObservedGeneration: 1234, }, + { + Type: "LDAPConnectionValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: fmt.Sprintf(`successfully able to connect to "%s" and bind as user "%s"`, testHost, testBindUsername), + ObservedGeneration: 1234, + }, { Type: "TLSConfigurationValid", Status: "True", @@ -605,11 +643,62 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { }, }, }, + { + name: "when testing the connection to the LDAP server fails then the upstream is not added to the cache", + inputUpstreams: []runtime.Object{validUpstream}, + inputSecrets: []runtime.Object{&corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{Name: testSecretName, Namespace: testNamespace}, + Type: corev1.SecretTypeBasicAuth, + Data: testValidSecretData, + }}, + setupMocks: func(conn *mockldapconn.MockConn) { + // Should perform a test dial and bind. + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1).Return(errors.New("some bind error")) + conn.EXPECT().Close().Times(1) + }, + wantErr: controllerlib.ErrSyntheticRequeue.Error(), + wantResultingCache: []*upstreamldap.ProviderConfig{}, + wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ + ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, + Status: v1alpha1.LDAPIdentityProviderStatus{ + Phase: "Error", + Conditions: []v1alpha1.Condition{ + { + Type: "BindSecretValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: "loaded bind secret", + ObservedGeneration: 1234, + }, + { + Type: "LDAPConnectionValid", + Status: "False", + LastTransitionTime: now, + Reason: "LDAPConnectionError", + Message: fmt.Sprintf( + `could not successfully connect to "%s" and bind as user "%s: error binding as "%s": some bind error`, + testHost, testBindUsername, testBindUsername), + ObservedGeneration: 1234, + }, + { + Type: "TLSConfigurationValid", + Status: "True", + LastTransitionTime: now, + Reason: "Success", + Message: "loaded TLS configuration", + ObservedGeneration: 1234, + }, + }, + }, + }}, + }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() + fakePinnipedClient := pinnipedfake.NewSimpleClientset(tt.inputUpstreams...) pinnipedInformers := pinnipedinformers.NewSharedInformerFactory(fakePinnipedClient, 0) fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...) @@ -619,9 +708,24 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}), }) + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + conn := mockldapconn.NewMockConn(ctrl) + if tt.setupMocks != nil { + tt.setupMocks(conn) + } + + dialer := &comparableDialer{f: upstreamldap.LDAPDialerFunc(func(ctx context.Context, _ string) (upstreamldap.Conn, error) { + if tt.dialError != nil { + return nil, tt.dialError + } + return conn, nil + })} + controller := NewLDAPUpstreamWatcherController( cache, - successfulDialer, + dialer, fakePinnipedClient, pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(), kubeInformers.Core().V1().Secrets(), @@ -647,7 +751,11 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) for i := range actualIDPList { actualIDP := actualIDPList[i].(*upstreamldap.Provider) - require.Equal(t, *tt.wantResultingCache[i], actualIDP.GetConfig()) + copyOfExpectedValue := *tt.wantResultingCache[i] // copy before edit to avoid race because these tests are run in parallel + // The dialer that was passed in to the controller's constructor should always have been + // passed through to the provider. + copyOfExpectedValue.Dialer = dialer + require.Equal(t, copyOfExpectedValue, actualIDP.GetConfig()) } actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{}) @@ -660,13 +768,6 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { // Require each separately to get a nice diff when the test fails. require.Equal(t, tt.wantResultingUpstreams[i], normalizedActualUpstreams[i]) } - - // Running the sync() a second time should be idempotent, and should return the same error. - if err := controllerlib.TestSync(t, controller, syncCtx); tt.wantErr != "" { - require.EqualError(t, err, tt.wantErr) - } else { - require.NoError(t, err) - } }) } } diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index f27551c7..e7694cf5 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -182,10 +182,24 @@ func (p *Provider) GetURL() string { // TestConnection provides a method for testing the connection and bind settings. It performs a dial and bind // and returns any errors that we encountered. -func (p *Provider) TestConnection(ctx context.Context) (*authenticator.Response, error) { - _, _ = p.dial(ctx) - // TODO implement me - return nil, nil +func (p *Provider) TestConnection(ctx context.Context) error { + err := p.validateConfig() + if err != nil { + return err + } + + conn, err := p.dial(ctx) + if err != nil { + return fmt.Errorf(`error dialing host "%s": %w`, p.c.Host, err) + } + defer conn.Close() + + err = conn.Bind(p.c.BindUsername, p.c.BindPassword) + if err != nil { + return fmt.Errorf(`error binding as "%s": %w`, p.c.BindUsername, err) + } + + return nil } // TestAuthenticateUser provides a method for testing all of the Provider settings in a kind of dry run of @@ -199,9 +213,9 @@ func (p *Provider) TestAuthenticateUser(ctx context.Context, testUsername string // Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator. func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) { - if p.c.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.c.UserSearch.Filter) == 0 { - // LDAP search filters do not allow searching by DN. - return nil, false, fmt.Errorf(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`) + err := p.validateConfig() + if err != nil { + return nil, false, err } if len(username) == 0 { @@ -239,6 +253,14 @@ func (p *Provider) AuthenticateUser(ctx context.Context, username, password stri return response, true, nil } +func (p *Provider) validateConfig() error { + if p.c.UserSearch.UsernameAttribute == distinguishedNameAttributeName && len(p.c.UserSearch.Filter) == 0 { + // LDAP search filters do not allow searching by DN, so we would have no reasonable default for Filter. + return fmt.Errorf(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`) + } + return nil +} + func (p *Provider) searchAndBindUser(conn Conn, username string, password string) (string, string, error) { searchResult, err := conn.Search(p.userSearchRequest(username)) if err != nil { diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index b6b21846..c2cb7de6 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -608,6 +608,101 @@ func TestAuthenticateUser(t *testing.T) { } } +func TestTestConnection(t *testing.T) { + providerConfig := func(editFunc func(p *ProviderConfig)) *ProviderConfig { + config := &ProviderConfig{ + Name: "some-provider-name", + Host: testHost, + CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test + BindUsername: testBindUsername, + BindPassword: testBindPassword, + UserSearch: UserSearchConfig{}, // not used by TestConnection + } + if editFunc != nil { + editFunc(config) + } + return config + } + + tests := []struct { + name string + providerConfig *ProviderConfig + setupMocks func(conn *mockldapconn.MockConn) + dialError error + wantError string + wantToSkipDial bool + }{ + { + name: "happy path", + providerConfig: providerConfig(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, + }, + { + name: "when dial fails", + providerConfig: providerConfig(nil), + dialError: errors.New("some dial error"), + wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), + }, + { + name: "when binding as the bind user returns an error", + providerConfig: providerConfig(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantError: fmt.Sprintf(`error binding as "%s": some bind error`, testBindUsername), + }, + { + name: "when the config is invalid", + providerConfig: providerConfig(func(p *ProviderConfig) { + // This particular combination of options is not allowed. + p.UserSearch.UsernameAttribute = "dn" + p.UserSearch.Filter = "" + }), + wantToSkipDial: true, + wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, + }, + } + + for _, test := range tests { + tt := test + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + conn := mockldapconn.NewMockConn(ctrl) + if tt.setupMocks != nil { + tt.setupMocks(conn) + } + + dialWasAttempted := false + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + dialWasAttempted = true + require.Equal(t, tt.providerConfig.Host, hostAndPort) + if tt.dialError != nil { + return nil, tt.dialError + } + return conn, nil + }) + + provider := New(*tt.providerConfig) + err := provider.TestConnection(context.Background()) + + require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) + + switch { + case tt.wantError != "": + require.EqualError(t, err, tt.wantError) + default: + require.NoError(t, err) + } + }) + } +} + func TestGetConfig(t *testing.T) { c := ProviderConfig{ Name: "original-provider-name",