Func ldap.Conn.Close() now returns an error

- https://github.com/go-ldap/ldap/compare/v3.4.4...v3.4.5
This commit is contained in:
Joshua Casey 2023-06-16 13:24:45 -05:00
parent dbbaf9b969
commit 67cd5e70c2
2 changed files with 14 additions and 6 deletions

View File

@ -53,9 +53,10 @@ func (mr *MockConnMockRecorder) Bind(arg0, arg1 interface{}) *gomock.Call {
} }
// Close mocks base method. // Close mocks base method.
func (m *MockConn) Close() { func (m *MockConn) Close() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Close") m.ctrl.Call(m, "Close")
return nil
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close.

View File

@ -49,7 +49,7 @@ type Conn interface {
SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error) SearchWithPaging(searchRequest *ldap.SearchRequest, pagingSize uint32) (*ldap.SearchResult, error)
Close() Close() error
} }
// Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn. // Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn.
@ -181,6 +181,13 @@ func (p *Provider) GetConfig() ProviderConfig {
return p.c return p.c
} }
func closeAndLogError(conn Conn, doingWhat string) {
err := conn.Close()
if err != nil {
plog.Error(fmt.Sprintf("error closing LDAP connection when %s", doingWhat), err)
}
}
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.RefreshAttributes) ([]string, error) { func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.RefreshAttributes) ([]string, error) {
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
@ -190,7 +197,7 @@ func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes p
if err != nil { if err != nil {
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
} }
defer conn.Close() defer closeAndLogError(conn, "refreshing connection")
err = conn.Bind(p.c.BindUsername, p.c.BindPassword) err = conn.Bind(p.c.BindUsername, p.c.BindPassword)
if err != nil { if err != nil {
@ -402,7 +409,7 @@ func (p *Provider) TestConnection(ctx context.Context) error {
if err != nil { if err != nil {
return fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) return fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
} }
defer conn.Close() defer closeAndLogError(conn, "testing connection")
err = conn.Bind(p.c.BindUsername, p.c.BindPassword) err = conn.Bind(p.c.BindUsername, p.c.BindPassword)
if err != nil { if err != nil {
@ -453,7 +460,7 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, gr
p.traceAuthFailure(t, err) p.traceAuthFailure(t, err)
return nil, false, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) return nil, false, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
} }
defer conn.Close() defer closeAndLogError(conn, "authenticating user")
err = conn.Bind(p.c.BindUsername, p.c.BindPassword) err = conn.Bind(p.c.BindUsername, p.c.BindPassword)
if err != nil { if err != nil {
@ -534,7 +541,7 @@ func (p *Provider) SearchForDefaultNamingContext(ctx context.Context) (string, e
p.traceSearchBaseDiscoveryFailure(t, err) p.traceSearchBaseDiscoveryFailure(t, err)
return "", fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) return "", fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
} }
defer conn.Close() defer closeAndLogError(conn, "searching for default naming context")
err = conn.Bind(p.c.BindUsername, p.c.BindPassword) err = conn.Bind(p.c.BindUsername, p.c.BindPassword)
if err != nil { if err != nil {