diff --git a/internal/mocks/mockldapconn/mockldapconn.go b/internal/mocks/mockldapconn/mockldapconn.go index 0caa243a..0054661b 100644 --- a/internal/mocks/mockldapconn/mockldapconn.go +++ b/internal/mocks/mockldapconn/mockldapconn.go @@ -53,9 +53,10 @@ func (mr *MockConnMockRecorder) Bind(arg0, arg1 interface{}) *gomock.Call { } // Close mocks base method. -func (m *MockConn) Close() { +func (m *MockConn) Close() error { m.ctrl.T.Helper() m.ctrl.Call(m, "Close") + return nil } // Close indicates an expected call of Close. diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 80d02b00..bbf645c1 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -49,7 +49,7 @@ type Conn interface { SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) - Close() + Close() error } // Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn. @@ -181,6 +181,13 @@ func (p *Provider) GetConfig() ProviderConfig { return p.c } +func closeAndLogError(conn Conn, doingWhat string) { + err := conn.Close() + if err != nil { + plog.Error(fmt.Sprintf("error closing LDAP connection when %s", doingWhat), err) + } +} + func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.RefreshAttributes) ([]string, error) { t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches @@ -190,7 +197,7 @@ func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes p if err != nil { return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) } - defer conn.Close() + defer closeAndLogError(conn, "refreshing connection") err = conn.Bind(p.c.BindUsername, p.c.BindPassword) if err != nil { @@ -402,7 +409,7 @@ func (p *Provider) TestConnection(ctx context.Context) error { if err != nil { return fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) } - defer conn.Close() + defer closeAndLogError(conn, "testing connection") err = conn.Bind(p.c.BindUsername, p.c.BindPassword) if err != nil { @@ -453,7 +460,7 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, gr p.traceAuthFailure(t, err) return nil, false, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) } - defer conn.Close() + defer closeAndLogError(conn, "authenticating user") err = conn.Bind(p.c.BindUsername, p.c.BindPassword) if err != nil { @@ -534,7 +541,7 @@ func (p *Provider) SearchForDefaultNamingContext(ctx context.Context) (string, e p.traceSearchBaseDiscoveryFailure(t, err) return "", fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) } - defer conn.Close() + defer closeAndLogError(conn, "searching for default naming context") err = conn.Bind(p.c.BindUsername, p.c.BindPassword) if err != nil {