upstreamldap.New() now supports a StartTLS config option
- This enhances our LDAP client code to make it possible to optionally dial an LDAP server without TLS and then use StartTLS to upgrade the connection to TLS. - The controller for LDAPIdentityProviders is not using this option yet. That will come in a future commit.
This commit is contained in:
parent
94d6b76958
commit
025b37f839
@ -152,8 +152,9 @@ func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *
|
|||||||
spec := upstream.Spec
|
spec := upstream.Spec
|
||||||
|
|
||||||
config := &upstreamldap.ProviderConfig{
|
config := &upstreamldap.ProviderConfig{
|
||||||
Name: upstream.Name,
|
Name: upstream.Name,
|
||||||
Host: spec.Host,
|
Host: spec.Host,
|
||||||
|
ConnectionProtocol: upstreamldap.TLS,
|
||||||
UserSearch: upstreamldap.UserSearchConfig{
|
UserSearch: upstreamldap.UserSearchConfig{
|
||||||
Base: spec.UserSearch.Base,
|
Base: spec.UserSearch.Base,
|
||||||
Filter: spec.UserSearch.Filter,
|
Filter: spec.UserSearch.Filter,
|
||||||
|
@ -197,11 +197,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
providerConfigForValidUpstream := &upstreamldap.ProviderConfig{
|
providerConfigForValidUpstream := &upstreamldap.ProviderConfig{
|
||||||
Name: testName,
|
Name: testName,
|
||||||
Host: testHost,
|
Host: testHost,
|
||||||
CABundle: testCABundle,
|
ConnectionProtocol: upstreamldap.TLS,
|
||||||
BindUsername: testBindUsername,
|
CABundle: testCABundle,
|
||||||
BindPassword: testBindPassword,
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
UserSearch: upstreamldap.UserSearchConfig{
|
UserSearch: upstreamldap.UserSearchConfig{
|
||||||
Base: testUserSearchBase,
|
Base: testUserSearchBase,
|
||||||
Filter: testUserSearchFilter,
|
Filter: testUserSearchFilter,
|
||||||
@ -442,11 +443,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantResultingCache: []*upstreamldap.ProviderConfig{
|
wantResultingCache: []*upstreamldap.ProviderConfig{
|
||||||
{
|
{
|
||||||
Name: testName,
|
Name: testName,
|
||||||
Host: testHost,
|
Host: testHost,
|
||||||
CABundle: nil,
|
ConnectionProtocol: upstreamldap.TLS,
|
||||||
BindUsername: testBindUsername,
|
CABundle: nil,
|
||||||
BindPassword: testBindPassword,
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
UserSearch: upstreamldap.UserSearchConfig{
|
UserSearch: upstreamldap.UserSearchConfig{
|
||||||
Base: testUserSearchBase,
|
Base: testUserSearchBase,
|
||||||
Filter: testUserSearchFilter,
|
Filter: testUserSearchFilter,
|
||||||
@ -493,11 +495,12 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
|
|||||||
},
|
},
|
||||||
wantResultingCache: []*upstreamldap.ProviderConfig{
|
wantResultingCache: []*upstreamldap.ProviderConfig{
|
||||||
{
|
{
|
||||||
Name: testName,
|
Name: testName,
|
||||||
Host: testHost,
|
Host: testHost,
|
||||||
CABundle: nil,
|
ConnectionProtocol: upstreamldap.TLS,
|
||||||
BindUsername: testBindUsername,
|
CABundle: nil,
|
||||||
BindPassword: testBindPassword,
|
BindUsername: testBindUsername,
|
||||||
|
BindPassword: testBindPassword,
|
||||||
UserSearch: upstreamldap.UserSearchConfig{
|
UserSearch: upstreamldap.UserSearchConfig{
|
||||||
Base: testUserSearchBase,
|
Base: testUserSearchBase,
|
||||||
Filter: testUserSearchFilter,
|
Filter: testUserSearchFilter,
|
||||||
|
@ -60,6 +60,13 @@ func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, err
|
|||||||
return f(ctx, hostAndPort)
|
return f(ctx, hostAndPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type LDAPConnectionProtocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StartTLS = LDAPConnectionProtocol("StartTLS")
|
||||||
|
TLS = LDAPConnectionProtocol("TLS")
|
||||||
|
)
|
||||||
|
|
||||||
// ProviderConfig includes all of the settings for connection and searching for users and groups in
|
// ProviderConfig includes all of the settings for connection and searching for users and groups in
|
||||||
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
|
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
|
||||||
// The nested structs are not pointer fields to enable deep copy on function params and return values.
|
// The nested structs are not pointer fields to enable deep copy on function params and return values.
|
||||||
@ -71,6 +78,9 @@ type ProviderConfig struct {
|
|||||||
// the default LDAP port will be used.
|
// the default LDAP port will be used.
|
||||||
Host string
|
Host string
|
||||||
|
|
||||||
|
// ConnectionProtocol determines how to establish the connection to the server. Either StartTLS or TLS.
|
||||||
|
ConnectionProtocol LDAPConnectionProtocol
|
||||||
|
|
||||||
// PEM-encoded CA cert bundle to trust when connecting to the LDAP server. Can be nil.
|
// PEM-encoded CA cert bundle to trust when connecting to the LDAP server. Can be nil.
|
||||||
CABundle []byte
|
CABundle []byte
|
||||||
|
|
||||||
@ -137,33 +147,38 @@ func (p *Provider) GetConfig() ProviderConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) dial(ctx context.Context) (Conn, error) {
|
func (p *Provider) dial(ctx context.Context) (Conn, error) {
|
||||||
hostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
|
tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
}
|
}
|
||||||
if p.c.Dialer != nil {
|
|
||||||
return p.c.Dialer.Dial(ctx, hostAndPort)
|
startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case p.c.Dialer != nil:
|
||||||
|
return p.c.Dialer.Dial(ctx, tlsHostAndPort)
|
||||||
|
case p.c.ConnectionProtocol == TLS:
|
||||||
|
return p.dialTLS(ctx, tlsHostAndPort)
|
||||||
|
case p.c.ConnectionProtocol == StartTLS:
|
||||||
|
return p.dialStartTLS(ctx, startTLSHostAndPort)
|
||||||
|
default:
|
||||||
|
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol"))
|
||||||
}
|
}
|
||||||
return p.dialTLS(ctx, hostAndPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// dialTLS is the default implementation of the Dialer, used when Dialer is nil.
|
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS.
|
||||||
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
|
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
|
||||||
// so we implement it ourselves, heavily inspired by ldap.DialURL.
|
// so we implement it ourselves, heavily inspired by ldap.DialURL.
|
||||||
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
|
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||||
var rootCAs *x509.CertPool
|
tlsConfig, err := p.tlsConfig()
|
||||||
if p.c.CABundle != nil {
|
if err != nil {
|
||||||
rootCAs = x509.NewCertPool()
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
if !rootCAs.AppendCertsFromPEM(p.c.CABundle) {
|
|
||||||
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dialer := &tls.Dialer{Config: &tls.Config{
|
dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig}
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
RootCAs: rootCAs,
|
|
||||||
}}
|
|
||||||
|
|
||||||
c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
|
c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
@ -174,6 +189,52 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS.
|
||||||
|
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
|
||||||
|
// so we implement it ourselves, heavily inspired by ldap.DialURL.
|
||||||
|
func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||||
|
tlsConfig, err := p.tlsConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
host, err := hostWithoutPort(hostAndPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
|
}
|
||||||
|
// Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS.
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
|
||||||
|
c, err := netDialer().DialContext(ctx, "tcp", hostAndPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := ldap.NewConn(c, false)
|
||||||
|
conn.Start()
|
||||||
|
err = conn.StartTLS(tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func netDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{Timeout: time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) tlsConfig() (*tls.Config, error) {
|
||||||
|
var rootCAs *x509.CertPool
|
||||||
|
if p.c.CABundle != nil {
|
||||||
|
rootCAs = x509.NewCertPool()
|
||||||
|
if !rootCAs.AppendCertsFromPEM(p.c.CABundle) {
|
||||||
|
return nil, fmt.Errorf("could not parse CA bundle")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Adds the default port if hostAndPort did not already include a port.
|
// Adds the default port if hostAndPort did not already include a port.
|
||||||
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
|
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
|
||||||
host, port, err := net.SplitHostPort(hostAndPort)
|
host, port, err := net.SplitHostPort(hostAndPort)
|
||||||
@ -188,7 +249,7 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string,
|
|||||||
switch {
|
switch {
|
||||||
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
|
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
|
||||||
// don't add extra square brackets to an IPv6 address that already has them
|
// don't add extra square brackets to an IPv6 address that already has them
|
||||||
return host + ":" + port, nil
|
return fmt.Sprintf("%s:%s", host, port), nil
|
||||||
case port != "":
|
case port != "":
|
||||||
return net.JoinHostPort(host, port), nil
|
return net.JoinHostPort(host, port), nil
|
||||||
default:
|
default:
|
||||||
@ -196,6 +257,22 @@ func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strip the port from a host or host:port.
|
||||||
|
func hostWithoutPort(hostAndPort string) (string, error) {
|
||||||
|
host, _, err := net.SplitHostPort(hostAndPort)
|
||||||
|
if err != nil {
|
||||||
|
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
|
||||||
|
return hostAndPort, nil
|
||||||
|
}
|
||||||
|
return "", err // hostAndPort argument was not parsable
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(hostAndPort, "[") {
|
||||||
|
// it was an IPv6 address, so preserve the square brackets.
|
||||||
|
return fmt.Sprintf("[%s]", host), nil
|
||||||
|
}
|
||||||
|
return host, nil
|
||||||
|
}
|
||||||
|
|
||||||
// A name for this upstream provider.
|
// A name for this upstream provider.
|
||||||
func (p *Provider) GetName() string {
|
func (p *Provider) GetName() string {
|
||||||
return p.c.Name
|
return p.c.Name
|
||||||
|
@ -1132,27 +1132,55 @@ func TestRealTLSDialing(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
|
connProto LDAPConnectionProtocol
|
||||||
caBundle []byte
|
caBundle []byte
|
||||||
context context.Context
|
context context.Context
|
||||||
wantError string
|
wantError string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path",
|
name: "happy path",
|
||||||
host: testServerHostAndPort,
|
host: testServerHostAndPort,
|
||||||
caBundle: []byte(testServerCABundle),
|
caBundle: []byte(testServerCABundle),
|
||||||
context: context.Background(),
|
connProto: TLS,
|
||||||
|
context: context.Background(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid CA bundle",
|
name: "invalid CA bundle with TLS",
|
||||||
host: testServerHostAndPort,
|
host: testServerHostAndPort,
|
||||||
caBundle: []byte("not a ca bundle"),
|
caBundle: []byte("not a ca bundle"),
|
||||||
|
connProto: TLS,
|
||||||
context: context.Background(),
|
context: context.Background(),
|
||||||
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
|
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "invalid CA bundle with StartTLS",
|
||||||
|
host: testServerHostAndPort,
|
||||||
|
caBundle: []byte("not a ca bundle"),
|
||||||
|
connProto: StartTLS,
|
||||||
|
context: context.Background(),
|
||||||
|
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid host with TLS",
|
||||||
|
host: "this:is:not:a:valid:hostname",
|
||||||
|
caBundle: []byte(testServerCABundle),
|
||||||
|
connProto: TLS,
|
||||||
|
context: context.Background(),
|
||||||
|
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid host with StartTLS",
|
||||||
|
host: "this:is:not:a:valid:hostname",
|
||||||
|
caBundle: []byte(testServerCABundle),
|
||||||
|
connProto: StartTLS,
|
||||||
|
context: context.Background(),
|
||||||
|
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "missing CA bundle when it is required because the host is not using a trusted CA",
|
name: "missing CA bundle when it is required because the host is not using a trusted CA",
|
||||||
host: testServerHostAndPort,
|
host: testServerHostAndPort,
|
||||||
caBundle: nil,
|
caBundle: nil,
|
||||||
|
connProto: TLS,
|
||||||
context: context.Background(),
|
context: context.Background(),
|
||||||
wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`,
|
wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`,
|
||||||
},
|
},
|
||||||
@ -1161,6 +1189,7 @@ func TestRealTLSDialing(t *testing.T) {
|
|||||||
// This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough.
|
// This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough.
|
||||||
host: recentlyClaimedHostAndPort,
|
host: recentlyClaimedHostAndPort,
|
||||||
caBundle: []byte(testServerCABundle),
|
caBundle: []byte(testServerCABundle),
|
||||||
|
connProto: TLS,
|
||||||
context: context.Background(),
|
context: context.Background(),
|
||||||
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort),
|
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort),
|
||||||
},
|
},
|
||||||
@ -1168,25 +1197,35 @@ func TestRealTLSDialing(t *testing.T) {
|
|||||||
name: "pays attention to the passed context",
|
name: "pays attention to the passed context",
|
||||||
host: testServerHostAndPort,
|
host: testServerHostAndPort,
|
||||||
caBundle: []byte(testServerCABundle),
|
caBundle: []byte(testServerCABundle),
|
||||||
|
connProto: TLS,
|
||||||
context: alreadyCancelledContext,
|
context: alreadyCancelledContext,
|
||||||
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort),
|
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "unsupported connection protocol",
|
||||||
|
host: testServerHostAndPort,
|
||||||
|
caBundle: []byte(testServerCABundle),
|
||||||
|
connProto: "bad usage of this type",
|
||||||
|
context: alreadyCancelledContext,
|
||||||
|
wantError: `LDAP Result Code 200 "Network Error": did not specify valid ConnectionProtocol`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
tt := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
provider := New(ProviderConfig{
|
provider := New(ProviderConfig{
|
||||||
Host: test.host,
|
Host: tt.host,
|
||||||
CABundle: test.caBundle,
|
CABundle: tt.caBundle,
|
||||||
Dialer: nil, // this test is for the default (production) dialer
|
ConnectionProtocol: tt.connProto,
|
||||||
|
Dialer: nil, // this test is for the default (production) TLS dialer
|
||||||
})
|
})
|
||||||
conn, err := provider.dial(test.context)
|
conn, err := provider.dial(tt.context)
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
}
|
}
|
||||||
if test.wantError != "" {
|
if tt.wantError != "" {
|
||||||
require.Nil(t, conn)
|
require.Nil(t, conn)
|
||||||
require.EqualError(t, err, test.wantError)
|
require.EqualError(t, err, tt.wantError)
|
||||||
} else {
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
@ -1231,6 +1270,12 @@ func TestHostAndPortWithDefaultPort(t *testing.T) {
|
|||||||
defaultPort: "",
|
defaultPort: "",
|
||||||
wantHostAndPort: "host.example.com",
|
wantHostAndPort: "host.example.com",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "host has port and default port is empty",
|
||||||
|
hostAndPort: "host.example.com:42",
|
||||||
|
defaultPort: "",
|
||||||
|
wantHostAndPort: "host.example.com:42",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 host already has port",
|
name: "IPv6 host already has port",
|
||||||
hostAndPort: "[::1%lo0]:80",
|
hostAndPort: "[::1%lo0]:80",
|
||||||
@ -1257,15 +1302,63 @@ func TestHostAndPortWithDefaultPort(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
tt := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
hostAndPort, err := hostAndPortWithDefaultPort(test.hostAndPort, test.defaultPort)
|
hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort)
|
||||||
if test.wantError != "" {
|
if tt.wantError != "" {
|
||||||
require.EqualError(t, err, test.wantError)
|
require.EqualError(t, err, tt.wantError)
|
||||||
} else {
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
require.Equal(t, test.wantHostAndPort, hostAndPort)
|
require.Equal(t, tt.wantHostAndPort, hostAndPort)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test various cases of host and port parsing.
|
||||||
|
func TestHostWithoutPort(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
hostAndPort string
|
||||||
|
wantError string
|
||||||
|
wantHostAndPort string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "host already has port",
|
||||||
|
hostAndPort: "host.example.com:99",
|
||||||
|
wantHostAndPort: "host.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host does not have port",
|
||||||
|
hostAndPort: "host.example.com",
|
||||||
|
wantHostAndPort: "host.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 host already has port",
|
||||||
|
hostAndPort: "[::1%lo0]:80",
|
||||||
|
wantHostAndPort: "[::1%lo0]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 host does not have port",
|
||||||
|
hostAndPort: "[::1%lo0]",
|
||||||
|
wantHostAndPort: "[::1%lo0]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host is not valid",
|
||||||
|
hostAndPort: "host.example.com:port1:port2",
|
||||||
|
wantError: "address host.example.com:port1:port2: too many colons in address",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
tt := test
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hostAndPort, err := hostWithoutPort(tt.hostAndPort)
|
||||||
|
if tt.wantError != "" {
|
||||||
|
require.EqualError(t, err, tt.wantError)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
require.Equal(t, tt.wantHostAndPort, hostAndPort)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,15 +37,19 @@ func TestLDAPSearch(t *testing.T) {
|
|||||||
cancelFunc() // this will send SIGKILL to the subprocess, just in case
|
cancelFunc() // this will send SIGKILL to the subprocess, just in case
|
||||||
})
|
})
|
||||||
|
|
||||||
hostPorts := findRecentlyUnusedLocalhostPorts(t, 2)
|
localhostPorts := findRecentlyUnusedLocalhostPorts(t, 3)
|
||||||
ldapHostPort := hostPorts[0]
|
ldapLocalhostPort := localhostPorts[0]
|
||||||
unusedHostPort := hostPorts[1]
|
ldapsLocalhostPort := localhostPorts[1]
|
||||||
|
unusedLocalhostPort := localhostPorts[2]
|
||||||
|
|
||||||
// Expose the the test LDAP server's TLS port on the localhost.
|
// Expose the the test LDAP server's TLS port on the localhost.
|
||||||
startKubectlPortForward(ctx, t, ldapHostPort, "ldaps", "ldap", env.ToolsNamespace)
|
startKubectlPortForward(ctx, t, ldapsLocalhostPort, "ldaps", "ldap", env.ToolsNamespace)
|
||||||
|
|
||||||
|
// Expose the the test LDAP server's StartTLS port on the localhost.
|
||||||
|
startKubectlPortForward(ctx, t, ldapLocalhostPort, "ldap", "ldap", env.ToolsNamespace)
|
||||||
|
|
||||||
providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig {
|
providerConfig := func(editFunc func(p *upstreamldap.ProviderConfig)) *upstreamldap.ProviderConfig {
|
||||||
providerConfig := defaultProviderConfig(env, ldapHostPort)
|
providerConfig := defaultProviderConfig(env, ldapsLocalhostPort)
|
||||||
if editFunc != nil {
|
if editFunc != nil {
|
||||||
editFunc(providerConfig)
|
editFunc(providerConfig)
|
||||||
}
|
}
|
||||||
@ -64,7 +68,7 @@ func TestLDAPSearch(t *testing.T) {
|
|||||||
wantUnauthenticated bool
|
wantUnauthenticated bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path",
|
name: "happy path with TLS",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
password: pinnyPassword,
|
password: pinnyPassword,
|
||||||
provider: upstreamldap.New(*providerConfig(nil)),
|
provider: upstreamldap.New(*providerConfig(nil)),
|
||||||
@ -72,6 +76,18 @@ func TestLDAPSearch(t *testing.T) {
|
|||||||
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}},
|
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "happy path with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + ldapLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
})),
|
||||||
|
wantAuthResponse: &authenticator.Response{
|
||||||
|
User: &user.DefaultInfo{Name: "pinny", UID: "1000", Groups: []string{"ball-game-players", "seals"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "using a different user search base",
|
name: "using a different user search base",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
@ -251,6 +267,17 @@ func TestLDAPSearch(t *testing.T) {
|
|||||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })),
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })),
|
||||||
wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
|
wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "when the bind user username is wrong with StartTLS: example of an error after successful connection with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + ldapLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
p.BindUsername = "cn=wrong,dc=pinniped,dc=dev"
|
||||||
|
})),
|
||||||
|
wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "when the end user password is wrong",
|
name: "when the end user password is wrong",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
@ -296,32 +323,89 @@ func TestLDAPSearch(t *testing.T) {
|
|||||||
wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `,
|
wantError: `error searching for user "pinny": LDAP Result Code 4 "Size Limit Exceeded": `,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the server is unreachable",
|
name: "when the server is unreachable with TLS",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
password: pinnyPassword,
|
password: pinnyPassword,
|
||||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedHostPort })),
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedLocalhostPort })),
|
||||||
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedHostPort, unusedHostPort),
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the server is not parsable",
|
name: "when the server is unreachable with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + unusedLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
})),
|
||||||
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the server is not parsable with TLS",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
password: pinnyPassword,
|
password: pinnyPassword,
|
||||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
|
||||||
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
|
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the CA bundle is not parsable",
|
name: "when the server is not parsable with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + ldapLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
p.Host = "too:many:ports"
|
||||||
|
})),
|
||||||
|
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the CA bundle is not parsable with TLS",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
password: pinnyPassword,
|
password: pinnyPassword,
|
||||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })),
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })),
|
||||||
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapHostPort),
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapsLocalhostPort),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the CA bundle does not cause the host to be trusted",
|
name: "when the CA bundle is not parsable with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + ldapLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
p.CABundle = []byte("invalid-pem")
|
||||||
|
})),
|
||||||
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapLocalhostPort),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the CA bundle does not cause the host to be trusted with TLS",
|
||||||
username: "pinny",
|
username: "pinny",
|
||||||
password: pinnyPassword,
|
password: pinnyPassword,
|
||||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })),
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })),
|
||||||
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapHostPort),
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`, ldapsLocalhostPort),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the CA bundle does not cause the host to be trusted with StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
|
||||||
|
p.Host = "127.0.0.1:" + ldapLocalhostPort
|
||||||
|
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||||
|
p.CABundle = nil
|
||||||
|
})),
|
||||||
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": TLS handshake failed (x509: certificate signed by unknown authority)`, ldapLocalhostPort),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when trying to use TLS to connect to a port which only supports StartTLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + ldapLocalhostPort })),
|
||||||
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": EOF`, ldapLocalhostPort),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when trying to use StartTLS to connect to a port which only supports TLS",
|
||||||
|
username: "pinny",
|
||||||
|
password: pinnyPassword,
|
||||||
|
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.ConnectionProtocol = upstreamldap.StartTLS })),
|
||||||
|
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": unable to read LDAP response packet: unexpected EOF`, ldapsLocalhostPort),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the UsernameAttribute attribute has multiple values in the entry",
|
name: "when the UsernameAttribute attribute has multiple values in the entry",
|
||||||
@ -541,13 +625,14 @@ type authUserResult struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultProviderConfig(env *library.TestEnv, ldapHostPort string) *upstreamldap.ProviderConfig {
|
func defaultProviderConfig(env *library.TestEnv, port string) *upstreamldap.ProviderConfig {
|
||||||
return &upstreamldap.ProviderConfig{
|
return &upstreamldap.ProviderConfig{
|
||||||
Name: "test-ldap-provider",
|
Name: "test-ldap-provider",
|
||||||
Host: "127.0.0.1:" + ldapHostPort,
|
Host: "127.0.0.1:" + port,
|
||||||
CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle),
|
ConnectionProtocol: upstreamldap.TLS,
|
||||||
BindUsername: "cn=admin,dc=pinniped,dc=dev",
|
CABundle: []byte(env.SupervisorUpstreamLDAP.CABundle),
|
||||||
BindPassword: "password",
|
BindUsername: "cn=admin,dc=pinniped,dc=dev",
|
||||||
|
BindPassword: "password",
|
||||||
UserSearch: upstreamldap.UserSearchConfig{
|
UserSearch: upstreamldap.UserSearchConfig{
|
||||||
Base: "ou=users,dc=pinniped,dc=dev",
|
Base: "ou=users,dc=pinniped,dc=dev",
|
||||||
Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case
|
Filter: "", // defaults to UsernameAttribute={}, i.e. "cn={}" in this case
|
||||||
|
Loading…
Reference in New Issue
Block a user