Convert LDAP code to use endpointaddr package.
Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
d9a3992b3b
commit
89eff28549
@ -26,6 +26,7 @@ import (
|
||||
pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions"
|
||||
"go.pinniped.dev/internal/certauthority"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/endpointaddr"
|
||||
"go.pinniped.dev/internal/mocks/mockldapconn"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
@ -871,9 +872,9 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
|
||||
tt.setupMocks(conn)
|
||||
}
|
||||
|
||||
dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) {
|
||||
dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (upstreamldap.Conn, error) {
|
||||
if tt.dialErrors != nil {
|
||||
dialErr := tt.dialErrors[hostAndPort]
|
||||
dialErr := tt.dialErrors[addr.Endpoint()]
|
||||
if dialErr != nil {
|
||||
return nil, dialErr
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"k8s.io/apiserver/pkg/authentication/authenticator"
|
||||
"k8s.io/apiserver/pkg/authentication/user"
|
||||
|
||||
"go.pinniped.dev/internal/endpointaddr"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
@ -30,6 +31,8 @@ const (
|
||||
commonNameAttributeName = "cn"
|
||||
searchFilterInterpolationLocationMarker = "{}"
|
||||
groupSearchPageSize = uint32(250)
|
||||
defaultLDAPPort = uint16(389)
|
||||
defaultLDAPSPort = uint16(636)
|
||||
)
|
||||
|
||||
// Conn abstracts the upstream LDAP communication protocol (mostly for testing).
|
||||
@ -48,16 +51,16 @@ var _ Conn = &ldap.Conn{}
|
||||
|
||||
// LDAPDialer is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP.
|
||||
type LDAPDialer interface {
|
||||
Dial(ctx context.Context, hostAndPort string) (Conn, error)
|
||||
Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error)
|
||||
}
|
||||
|
||||
// LDAPDialerFunc makes it easy to use a func as an LDAPDialer.
|
||||
type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error)
|
||||
type LDAPDialerFunc func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error)
|
||||
|
||||
var _ LDAPDialer = LDAPDialerFunc(nil)
|
||||
|
||||
func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||
return f(ctx, hostAndPort)
|
||||
func (f LDAPDialerFunc) Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
return f(ctx, addr)
|
||||
}
|
||||
|
||||
type LDAPConnectionProtocol string
|
||||
@ -147,26 +150,26 @@ func (p *Provider) GetConfig() ProviderConfig {
|
||||
}
|
||||
|
||||
func (p *Provider) dial(ctx context.Context) (Conn, error) {
|
||||
tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
|
||||
tlsAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPSPort)
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
|
||||
startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort)
|
||||
startTLSAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPPort)
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
|
||||
// Choose how and where to dial based on TLS vs. StartTLS config option.
|
||||
var dialFunc LDAPDialerFunc
|
||||
var hostAndPort string
|
||||
var addr endpointaddr.HostPort
|
||||
switch {
|
||||
case p.c.ConnectionProtocol == TLS:
|
||||
dialFunc = p.dialTLS
|
||||
hostAndPort = tlsHostAndPort
|
||||
addr = tlsAddr
|
||||
case p.c.ConnectionProtocol == StartTLS:
|
||||
dialFunc = p.dialStartTLS
|
||||
hostAndPort = startTLSHostAndPort
|
||||
addr = startTLSAddr
|
||||
default:
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol"))
|
||||
}
|
||||
@ -176,20 +179,20 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) {
|
||||
dialFunc = p.c.Dialer.Dial
|
||||
}
|
||||
|
||||
return dialFunc(ctx, hostAndPort)
|
||||
return dialFunc(ctx, addr)
|
||||
}
|
||||
|
||||
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS.
|
||||
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
|
||||
// so we implement it ourselves, heavily inspired by ldap.DialURL.
|
||||
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||
func (p *Provider) dialTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
tlsConfig, err := p.tlsConfig()
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
|
||||
dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig}
|
||||
c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
|
||||
c, err := dialer.DialContext(ctx, "tcp", addr.Endpoint())
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
@ -202,20 +205,16 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error
|
||||
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS.
|
||||
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
|
||||
// so we implement it ourselves, heavily inspired by ldap.DialURL.
|
||||
func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||
func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
tlsConfig, err := p.tlsConfig()
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
|
||||
host, err := hostWithoutPort(hostAndPort)
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
// Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS.
|
||||
tlsConfig.ServerName = host
|
||||
tlsConfig.ServerName = addr.Host
|
||||
|
||||
c, err := netDialer().DialContext(ctx, "tcp", hostAndPort)
|
||||
c, err := netDialer().DialContext(ctx, "tcp", addr.Endpoint())
|
||||
if err != nil {
|
||||
return nil, ldap.NewError(ldap.ErrorNetwork, err)
|
||||
}
|
||||
@ -245,44 +244,6 @@ func (p *Provider) tlsConfig() (*tls.Config, error) {
|
||||
return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil
|
||||
}
|
||||
|
||||
// Adds the default port if hostAndPort did not already include a port.
|
||||
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
|
||||
host, port, err := net.SplitHostPort(hostAndPort)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
|
||||
host = hostAndPort
|
||||
port = defaultPort
|
||||
} else {
|
||||
return "", err // hostAndPort argument was not parsable
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
|
||||
// don't add extra square brackets to an IPv6 address that already has them
|
||||
return fmt.Sprintf("%s:%s", host, port), nil
|
||||
case port != "":
|
||||
return net.JoinHostPort(host, port), nil
|
||||
default:
|
||||
return host, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Strip the port from a host or host:port.
|
||||
func hostWithoutPort(hostAndPort string) (string, error) {
|
||||
host, _, err := net.SplitHostPort(hostAndPort)
|
||||
if err != nil {
|
||||
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
|
||||
return hostAndPort, nil
|
||||
}
|
||||
return "", err // hostAndPort argument was not parsable
|
||||
}
|
||||
if strings.HasPrefix(hostAndPort, "[") {
|
||||
// it was an IPv6 address, so preserve the square brackets.
|
||||
return fmt.Sprintf("[%s]", host), nil
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// A name for this upstream provider.
|
||||
func (p *Provider) GetName() string {
|
||||
return p.c.Name
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"k8s.io/apiserver/pkg/authentication/user"
|
||||
|
||||
"go.pinniped.dev/internal/certauthority"
|
||||
"go.pinniped.dev/internal/endpointaddr"
|
||||
"go.pinniped.dev/internal/mocks/mockldapconn"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
)
|
||||
@ -926,9 +927,9 @@ func TestEndUserAuthentication(t *testing.T) {
|
||||
}
|
||||
|
||||
dialWasAttempted := false
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
dialWasAttempted = true
|
||||
require.Equal(t, tt.providerConfig.Host, hostAndPort)
|
||||
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
|
||||
if tt.dialError != nil {
|
||||
return nil, tt.dialError
|
||||
}
|
||||
@ -1061,9 +1062,9 @@ func TestTestConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
dialWasAttempted := false
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
|
||||
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
|
||||
dialWasAttempted = true
|
||||
require.Equal(t, tt.providerConfig.Host, hostAndPort)
|
||||
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
|
||||
if tt.dialError != nil {
|
||||
return nil, tt.dialError
|
||||
}
|
||||
@ -1185,7 +1186,7 @@ func TestRealTLSDialing(t *testing.T) {
|
||||
caBundle: []byte(testServerCABundle),
|
||||
connProto: TLS,
|
||||
context: context.Background(),
|
||||
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
|
||||
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`,
|
||||
},
|
||||
{
|
||||
name: "invalid host with StartTLS",
|
||||
@ -1193,7 +1194,7 @@ func TestRealTLSDialing(t *testing.T) {
|
||||
caBundle: []byte(testServerCABundle),
|
||||
connProto: StartTLS,
|
||||
context: context.Background(),
|
||||
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
|
||||
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`,
|
||||
},
|
||||
{
|
||||
name: "missing CA bundle when it is required because the host is not using a trusted CA",
|
||||
@ -1261,123 +1262,3 @@ func TestRealTLSDialing(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test various cases of host and port parsing.
|
||||
func TestHostAndPortWithDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostAndPort string
|
||||
defaultPort string
|
||||
wantError string
|
||||
wantHostAndPort string
|
||||
}{
|
||||
{
|
||||
name: "host already has port",
|
||||
hostAndPort: "host.example.com:99",
|
||||
defaultPort: "42",
|
||||
wantHostAndPort: "host.example.com:99",
|
||||
},
|
||||
{
|
||||
name: "host does not have port",
|
||||
hostAndPort: "host.example.com",
|
||||
defaultPort: "42",
|
||||
wantHostAndPort: "host.example.com:42",
|
||||
},
|
||||
{
|
||||
name: "host does not have port and default port is empty",
|
||||
hostAndPort: "host.example.com",
|
||||
defaultPort: "",
|
||||
wantHostAndPort: "host.example.com",
|
||||
},
|
||||
{
|
||||
name: "host has port and default port is empty",
|
||||
hostAndPort: "host.example.com:42",
|
||||
defaultPort: "",
|
||||
wantHostAndPort: "host.example.com:42",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host already has port",
|
||||
hostAndPort: "[::1%lo0]:80",
|
||||
defaultPort: "42",
|
||||
wantHostAndPort: "[::1%lo0]:80",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host does not have port",
|
||||
hostAndPort: "[::1%lo0]",
|
||||
defaultPort: "42",
|
||||
wantHostAndPort: "[::1%lo0]:42",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host does not have port and default port is empty",
|
||||
hostAndPort: "[::1%lo0]",
|
||||
defaultPort: "",
|
||||
wantHostAndPort: "[::1%lo0]",
|
||||
},
|
||||
{
|
||||
name: "host is not valid",
|
||||
hostAndPort: "host.example.com:port1:port2",
|
||||
defaultPort: "42",
|
||||
wantError: "address host.example.com:port1:port2: too many colons in address",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
tt := test
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort)
|
||||
if tt.wantError != "" {
|
||||
require.EqualError(t, err, tt.wantError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantHostAndPort, hostAndPort)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test various cases of host and port parsing.
|
||||
func TestHostWithoutPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostAndPort string
|
||||
wantError string
|
||||
wantHostAndPort string
|
||||
}{
|
||||
{
|
||||
name: "host already has port",
|
||||
hostAndPort: "host.example.com:99",
|
||||
wantHostAndPort: "host.example.com",
|
||||
},
|
||||
{
|
||||
name: "host does not have port",
|
||||
hostAndPort: "host.example.com",
|
||||
wantHostAndPort: "host.example.com",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host already has port",
|
||||
hostAndPort: "[::1%lo0]:80",
|
||||
wantHostAndPort: "[::1%lo0]",
|
||||
},
|
||||
{
|
||||
name: "IPv6 host does not have port",
|
||||
hostAndPort: "[::1%lo0]",
|
||||
wantHostAndPort: "[::1%lo0]",
|
||||
},
|
||||
{
|
||||
name: "host is not valid",
|
||||
hostAndPort: "host.example.com:port1:port2",
|
||||
wantError: "address host.example.com:port1:port2: too many colons in address",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
tt := test
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hostAndPort, err := hostWithoutPort(tt.hostAndPort)
|
||||
if tt.wantError != "" {
|
||||
require.EqualError(t, err, tt.wantError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, tt.wantHostAndPort, hostAndPort)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -344,7 +344,7 @@ func TestLDAPSearch(t *testing.T) {
|
||||
username: "pinny",
|
||||
password: pinnyPassword,
|
||||
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
|
||||
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
|
||||
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`,
|
||||
},
|
||||
{
|
||||
name: "when the server is not parsable with StartTLS",
|
||||
@ -355,7 +355,7 @@ func TestLDAPSearch(t *testing.T) {
|
||||
p.ConnectionProtocol = upstreamldap.StartTLS
|
||||
p.Host = "too:many:ports"
|
||||
})),
|
||||
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
|
||||
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`,
|
||||
},
|
||||
{
|
||||
name: "when the CA bundle is not parsable with TLS",
|
||||
|
Loading…
Reference in New Issue
Block a user