upstreamldap.New() now supports a StartTLS config option

- This enhances our LDAP client code to make it possible to optionally
  dial an LDAP server without TLS and then use StartTLS to upgrade
  the connection to TLS.
- The controller for LDAPIdentityProviders is not using this option
  yet. That will come in a future commit.
This commit is contained in:
Ryan Richard 2021-05-19 17:17:44 -07:00
parent 94d6b76958
commit 025b37f839
5 changed files with 332 additions and 73 deletions

View File

@ -154,6 +154,7 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *
config := &upstreamldap.ProviderConfig{ config := &upstreamldap.ProviderConfig{
Name: upstream.Name, Name: upstream.Name,
Host: spec.Host, Host: spec.Host,
ConnectionProtocol: upstreamldap.TLS,
UserSearch: upstreamldap.UserSearchConfig{ UserSearch: upstreamldap.UserSearchConfig{
Base: spec.UserSearch.Base, Base: spec.UserSearch.Base,
Filter: spec.UserSearch.Filter, Filter: spec.UserSearch.Filter,

View File

@ -199,6 +199,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
providerConfigForValidUpstream := &upstreamldap.ProviderConfig{ providerConfigForValidUpstream := &upstreamldap.ProviderConfig{
Name: testName, Name: testName,
Host: testHost, Host: testHost,
ConnectionProtocol: upstreamldap.TLS,
CABundle: testCABundle, CABundle: testCABundle,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
@ -444,6 +445,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
{ {
Name: testName, Name: testName,
Host: testHost, Host: testHost,
ConnectionProtocol: upstreamldap.TLS,
CABundle: nil, CABundle: nil,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,
@ -495,6 +497,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
{ {
Name: testName, Name: testName,
Host: testHost, Host: testHost,
ConnectionProtocol: upstreamldap.TLS,
CABundle: nil, CABundle: nil,
BindUsername: testBindUsername, BindUsername: testBindUsername,
BindPassword: testBindPassword, BindPassword: testBindPassword,

View File

@ -60,6 +60,13 @@ func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, err
return f(ctx, hostAndPort) return f(ctx, hostAndPort)
} }
type LDAPConnectionProtocol string
const (
StartTLS = LDAPConnectionProtocol("StartTLS")
TLS = LDAPConnectionProtocol("TLS")
)
// ProviderConfig includes all of the settings for connection and searching for users and groups in // ProviderConfig includes all of the settings for connection and searching for users and groups in
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins. // the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
// The nested structs are not pointer fields to enable deep copy on function params and return values. // The nested structs are not pointer fields to enable deep copy on function params and return values.
@ -71,6 +78,9 @@ type ProviderConfig struct {
// the default LDAP port will be used. // the default LDAP port will be used.
Host string Host string
// ConnectionProtocol determines how to establish the connection to the server. Either StartTLS or TLS.
ConnectionProtocol LDAPConnectionProtocol
// PEM-encoded CA cert bundle to trust when connecting to the LDAP server. Can be nil. // PEM-encoded CA cert bundle to trust when connecting to the LDAP server. Can be nil.
CABundle []byte CABundle []byte
@ -137,33 +147,38 @@ func (p *Provider) GetConfig() ProviderConfig {
} }
func (p *Provider) dial(ctx context.Context) (Conn, error) { func (p *Provider) dial(ctx context.Context) (Conn, error) {
hostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
if err != nil { if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err) return nil, ldap.NewError(ldap.ErrorNetwork, err)
} }
if p.c.Dialer != nil {
return p.c.Dialer.Dial(ctx, hostAndPort) startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort)
} if err != nil {
return p.dialTLS(ctx, hostAndPort) return nil, ldap.NewError(ldap.ErrorNetwork, err)
} }
// dialTLS is the default implementation of the Dialer, used when Dialer is nil. switch {
case p.c.Dialer != nil:
return p.c.Dialer.Dial(ctx, tlsHostAndPort)
case p.c.ConnectionProtocol == TLS:
return p.dialTLS(ctx, tlsHostAndPort)
case p.c.ConnectionProtocol == StartTLS:
return p.dialStartTLS(ctx, startTLSHostAndPort)
default:
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol"))
}
}
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL. // so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
var rootCAs *x509.CertPool tlsConfig, err := p.tlsConfig()
if p.c.CABundle != nil { if err != nil {
rootCAs = x509.NewCertPool() return nil, ldap.NewError(ldap.ErrorNetwork, err)
if !rootCAs.AppendCertsFromPEM(p.c.CABundle) {
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle"))
}
} }
dialer := &tls.Dialer{Config: &tls.Config{ dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig}
MinVersion: tls.VersionTLS12,
RootCAs: rootCAs,
}}
c, err := dialer.DialContext(ctx, "tcp", hostAndPort) c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
if err != nil { if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err) return nil, ldap.NewError(ldap.ErrorNetwork, err)
@ -174,6 +189,52 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error
return conn, nil return conn, nil
} }
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) {
tlsConfig, err := p.tlsConfig()
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
host, err := hostWithoutPort(hostAndPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
// Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS.
tlsConfig.ServerName = host
c, err := netDialer().DialContext(ctx, "tcp", hostAndPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
conn := ldap.NewConn(c, false)
conn.Start()
err = conn.StartTLS(tlsConfig)
if err != nil {
return nil, err
}
return conn, nil
}
func netDialer() *net.Dialer {
return &net.Dialer{Timeout: time.Minute}
}
func (p *Provider) tlsConfig() (*tls.Config, error) {
var rootCAs *x509.CertPool
if p.c.CABundle != nil {
rootCAs = x509.NewCertPool()
if !rootCAs.AppendCertsFromPEM(p.c.CABundle) {
return nil, fmt.Errorf("could not parse CA bundle")
}
}
return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil
}
// Adds the default port if hostAndPort did not already include a port. // Adds the default port if hostAndPort did not already include a port.
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) { func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
host, port, err := net.SplitHostPort(hostAndPort) host, port, err := net.SplitHostPort(hostAndPort)
@ -188,7 +249,7 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string,
switch { switch {
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"): case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
// don't add extra square brackets to an IPv6 address that already has them // don't add extra square brackets to an IPv6 address that already has them
return host + ":" + port, nil return fmt.Sprintf("%s:%s", host, port), nil
case port != "": case port != "":
return net.JoinHostPort(host, port), nil return net.JoinHostPort(host, port), nil
default: default:
@ -196,6 +257,22 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string,
} }
} }
// Strip the port from a host or host:port.
func hostWithoutPort(hostAndPort string) (string, error) {
host, _, err := net.SplitHostPort(hostAndPort)
if err != nil {
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
return hostAndPort, nil
}
return "", err // hostAndPort argument was not parsable
}
if strings.HasPrefix(hostAndPort, "[") {
// it was an IPv6 address, so preserve the square brackets.
return fmt.Sprintf("[%s]", host), nil
}
return host, nil
}
// A name for this upstream provider. // A name for this upstream provider.
func (p *Provider) GetName() string { func (p *Provider) GetName() string {
return p.c.Name return p.c.Name

View File

@ -1132,6 +1132,7 @@ func TestRealTLSDialing(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
host string host string
connProto LDAPConnectionProtocol
caBundle []byte caBundle []byte
context context.Context context context.Context
wantError string wantError string
@ -1140,19 +1141,46 @@ func TestRealTLSDialing(t *testing.T) {
name: "happy path", name: "happy path",
host: testServerHostAndPort, host: testServerHostAndPort,
caBundle: []byte(testServerCABundle), caBundle: []byte(testServerCABundle),
connProto: TLS,
context: context.Background(), context: context.Background(),
}, },
{ {
name: "invalid CA bundle", name: "invalid CA bundle with TLS",
host: testServerHostAndPort, host: testServerHostAndPort,
caBundle: []byte("not a ca bundle"), caBundle: []byte("not a ca bundle"),
connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`, wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
}, },
{
name: "invalid CA bundle with StartTLS",
host: testServerHostAndPort,
caBundle: []byte("not a ca bundle"),
connProto: StartTLS,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
},
{
name: "invalid host with TLS",
host: "this:is:not:a:valid:hostname",
caBundle: []byte(testServerCABundle),
connProto: TLS,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
},
{
name: "invalid host with StartTLS",
host: "this:is:not:a:valid:hostname",
caBundle: []byte(testServerCABundle),
connProto: StartTLS,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
},
{ {
name: "missing CA bundle when it is required because the host is not using a trusted CA", name: "missing CA bundle when it is required because the host is not using a trusted CA",
host: testServerHostAndPort, host: testServerHostAndPort,
caBundle: nil, caBundle: nil,
connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`,
}, },
@ -1161,6 +1189,7 @@ func TestRealTLSDialing(t *testing.T) {
// This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough. // This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough.
host: recentlyClaimedHostAndPort, host: recentlyClaimedHostAndPort,
caBundle: []byte(testServerCABundle), caBundle: []byte(testServerCABundle),
connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort), wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort),
}, },
@ -1168,25 +1197,35 @@ func TestRealTLSDialing(t *testing.T) {
name: "pays attention to the passed context", name: "pays attention to the passed context",
host: testServerHostAndPort, host: testServerHostAndPort,
caBundle: []byte(testServerCABundle), caBundle: []byte(testServerCABundle),
connProto: TLS,
context: alreadyCancelledContext, context: alreadyCancelledContext,
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort), wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort),
}, },
{
name: "unsupported connection protocol",
host: testServerHostAndPort,
caBundle: []byte(testServerCABundle),
connProto: "bad usage of this type",
context: alreadyCancelledContext,
wantError: `LDAP Result Code 200 "Network Error": did not specify valid ConnectionProtocol`,
},
} }
for _, test := range tests { for _, test := range tests {
test := test tt := test
t.Run(test.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
provider := New(ProviderConfig{ provider := New(ProviderConfig{
Host: test.host, Host: tt.host,
CABundle: test.caBundle, CABundle: tt.caBundle,
Dialer: nil, // this test is for the default (production) dialer ConnectionProtocol: tt.connProto,
Dialer: nil, // this test is for the default (production) TLS dialer
}) })
conn, err := provider.dial(test.context) conn, err := provider.dial(tt.context)
if conn != nil { if conn != nil {
defer conn.Close() defer conn.Close()
} }
if test.wantError != "" { if tt.wantError != "" {
require.Nil(t, conn) require.Nil(t, conn)
require.EqualError(t, err, test.wantError) require.EqualError(t, err, tt.wantError)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn) require.NotNil(t, conn)
@ -1231,6 +1270,12 @@ func TestHostAndPortWithDefaultPort(t *testing.T) {
defaultPort: "", defaultPort: "",
wantHostAndPort: "host.example.com", wantHostAndPort: "host.example.com",
}, },
{
name: "host has port and default port is empty",
hostAndPort: "host.example.com:42",
defaultPort: "",
wantHostAndPort: "host.example.com:42",
},
{ {
name: "IPv6 host already has port", name: "IPv6 host already has port",
hostAndPort: "[::1%lo0]:80", hostAndPort: "[::1%lo0]:80",
@ -1257,15 +1302,63 @@ func TestHostAndPortWithDefaultPort(t *testing.T) {
}, },
} }
for _, test := range tests { for _, test := range tests {
test := test tt := test
t.Run(test.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
hostAndPort, err := hostAndPortWithDefaultPort(test.hostAndPort, test.defaultPort) hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort)
if test.wantError != "" { if tt.wantError != "" {
require.EqualError(t, err, test.wantError) require.EqualError(t, err, tt.wantError)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
} }
require.Equal(t, test.wantHostAndPort, hostAndPort) require.Equal(t, tt.wantHostAndPort, hostAndPort)
})
}
}
// Test various cases of host and port parsing.
func TestHostWithoutPort(t *testing.T) {
tests := []struct {
name string
hostAndPort string
wantError string
wantHostAndPort string
}{
{
name: "host already has port",
hostAndPort: "host.example.com:99",
wantHostAndPort: "host.example.com",
},
{
name: "host does not have port",
hostAndPort: "host.example.com",
wantHostAndPort: "host.example.com",
},
{
name: "IPv6 host already has port",
hostAndPort: "[::1%lo0]:80",
wantHostAndPort: "[::1%lo0]",
},
{
name: "IPv6 host does not have port",
hostAndPort: "[::1%lo0]",
wantHostAndPort: "[::1%lo0]",
},
{
name: "host is not valid",
hostAndPort: "host.example.com:port1:port2",
wantError: "address host.example.com:port1:port2: too many colons in address",
},
}
for _, test := range tests {
tt := test
t.Run(tt.name, func(t *testing.T) {
hostAndPort, err := hostWithoutPort(tt.hostAndPort)
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantHostAndPort, hostAndPort)
}) })
} }
} }

View File

@ -37,15 +37,19 @@ func TestLDAPSearch(t *testing.T) {
cancelFunc() // this will send SIGKILL to the subprocess, just in case cancelFunc() // this will send SIGKILL to the subprocess, just in case
}) })
hostPorts := findRecentlyUnusedLocalhostPorts(t, 2) localhostPorts := findRecentlyUnusedLocalhostPorts(t, 3)
ldapHostPort := hostPorts[0] ldapLocalhostPort := localhostPorts[0]
unusedHostPort := hostPorts[1] ldapsLocalhostPort := localhostPorts[1]
unusedLocalhostPort := localhostPorts[2]
// Expose the the test LDAP server's TLS port on the localhost. // Expose the the test LDAP server's TLS port on the localhost.
startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace) startKubectlPortForward(ctx, t, ldapsLocalhostPort, "ldaps", "ldap", env.ToolsNamespace)
// Expose the the test LDAP server's StartTLS port on the localhost.
startKubectlPortForward(ctx, t, ldapLocalhostPort, "ldap", "ldap", env.ToolsNamespace)
providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig { providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig {
providerConfig := defaultProviderConfig(env, ldapHostPort) providerConfig := defaultProviderConfig(env, ldapsLocalhostPort)
if editFunc != nil { if editFunc != nil {
editFunc(providerConfig) editFunc(providerConfig)
} }
@ -64,7 +68,7 @@ func TestLDAPSearch(t *testing.T) {
wantUnauthenticated bool wantUnauthenticated bool
}{ }{
{ {
name: "happy path", name: "happy path with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(nil)), provider: upstreamldap.New(*providerConfig(nil)),
@ -72,6 +76,18 @@ func TestLDAPSearch(t *testing.T) {
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}}, User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}},
}, },
}, },
{
name: "happy path with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + ldapLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
})),
wantAuthResponse: &authenticator.Response{
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}},
},
},
{ {
name: "using a different user search base", name: "using a different user search base",
username: "pinny", username: "pinny",
@ -251,6 +267,17 @@ func TestLDAPSearch(t *testing.T) {
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })),
wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
}, },
{
name: "when the bind user username is wrong with StartTLS: example of an error after successful connection with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + ldapLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
p.BindUsername = "cn=wrong,dc=pinniped,dc=dev"
})),
wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
},
{ {
name: "when the end user password is wrong", name: "when the end user password is wrong",
username: "pinny", username: "pinny",
@ -296,32 +323,89 @@ func TestLDAPSearch(t *testing.T) {
wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `, wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `,
}, },
{ {
name: "when the server is unreachable", name: "when the server is unreachable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedHostPort })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedLocalhostPort })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
}, },
{ {
name: "when the server is not parsable", name: "when the server is unreachable with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + unusedLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
})),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
},
{
name: "when the server is not parsable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
}, },
{ {
name: "when the CA bundle is not parsable", name: "when the server is not parsable with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + ldapLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
p.Host = "too:many:ports"
})),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
},
{
name: "when the CA bundle is not parsable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapsLocalhostPort),
}, },
{ {
name: "when the CA bundle does not cause the host to be trusted", name: "when the CA bundle is not parsable with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + ldapLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
p.CABundle = []byte("invalid-pem")
})),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapLocalhostPort),
},
{
name: "when the CA bundle does not cause the host to be trusted with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort), wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapsLocalhostPort),
},
{
name: "when the CA bundle does not cause the host to be trusted with StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.Host = "127.0.0.1:" + ldapLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS
p.CABundle = nil
})),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": TLS handshake failed (x509: certificate signed by unknown authority)`, ldapLocalhostPort),
},
{
name: "when trying to use TLS to connect to a port which only supports StartTLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + ldapLocalhostPort })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": EOF`, ldapLocalhostPort),
},
{
name: "when trying to use StartTLS to connect to a port which only supports TLS",
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.ConnectionProtocol = upstreamldap.StartTLS })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": unable to read LDAP response packet: unexpected EOF`, ldapsLocalhostPort),
}, },
{ {
name: "when the UsernameAttribute attribute has multiple values in the entry", name: "when the UsernameAttribute attribute has multiple values in the entry",
@ -541,10 +625,11 @@ type authUserResult struct {
err error err error
} }
func defaultProviderConfig(env *library.TestEnv, ldapHostPort string) *upstreamldap.ProviderConfig { func defaultProviderConfig(env *library.TestEnv, port string) *upstreamldap.ProviderConfig {
return &upstreamldap.ProviderConfig{ return &upstreamldap.ProviderConfig{
Name: "test-ldap-provider", Name: "test-ldap-provider",
Host: "127.0.0.1:" + ldapHostPort, Host: "127.0.0.1:" + port,
ConnectionProtocol: upstreamldap.TLS,
CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle), CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle),
BindUsername: "cn=admin,dc=pinniped,dc=dev", BindUsername: "cn=admin,dc=pinniped,dc=dev",
BindPassword: "password", BindPassword: "password",