diff --git a/go.mod b/go.mod index 8004534f..82b6a2a0 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/go-openapi/spec v0.20.3 // indirect github.com/gofrs/flock v0.8.0 github.com/golang/mock v1.5.0 - github.com/google/go-cmp v0.5.5 + github.com/google/go-cmp v0.5.6 github.com/google/gofuzz v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/websocket v1.4.2 diff --git a/go.sum b/go.sum index 30a8df74..28caead8 100644 --- a/go.sum +++ b/go.sum @@ -514,8 +514,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-jsonnet v0.16.0/go.mod h1:sOcuej3UW1vpPTZOr8L7RQimqai1a57bt5j22LzGZCw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= diff --git a/internal/controller/impersonatorconfig/impersonator_config.go b/internal/controller/impersonatorconfig/impersonator_config.go index 2e9b1861..fbd55608 100644 --- a/internal/controller/impersonatorconfig/impersonator_config.go +++ b/internal/controller/impersonatorconfig/impersonator_config.go @@ -12,7 +12,6 @@ import ( "fmt" "net" "reflect" - "strconv" "strings" "time" @@ -41,6 +40,7 @@ import ( "go.pinniped.dev/internal/controller/issuerconfig" "go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/dynamiccert" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/plog" ) @@ -760,13 +760,13 @@ func (c *impersonatorConfigController) findDesiredTLSCertificateName(config *v1a } func (c *impersonatorConfigController) findTLSCertificateNameFromEndpointConfig(config *v1alpha1.ImpersonationProxySpec) *certNameInfo { - endpointMaybeWithPort := config.ExternalEndpoint - endpointWithoutPort := strings.Split(endpointMaybeWithPort, ":")[0] - parsedAsIP := net.ParseIP(endpointWithoutPort) - if parsedAsIP != nil { - return &certNameInfo{ready: true, selectedIPs: []net.IP{parsedAsIP}, clientEndpoint: endpointMaybeWithPort} + addr, _ := endpointaddr.Parse(config.ExternalEndpoint, 443) + endpoint := strings.TrimSuffix(addr.Endpoint(), ":443") + + if ip := net.ParseIP(addr.Host); ip != nil { + return &certNameInfo{ready: true, selectedIPs: []net.IP{ip}, clientEndpoint: endpoint} } - return &certNameInfo{ready: true, selectedHostname: endpointWithoutPort, clientEndpoint: endpointMaybeWithPort} + return &certNameInfo{ready: true, selectedHostname: addr.Host, clientEndpoint: endpoint} } func (c *impersonatorConfigController) findTLSCertificateNameFromLoadBalancer() (*certNameInfo, error) { @@ -1021,46 +1021,11 @@ func validateCredentialIssuerSpec(spec *v1alpha1.ImpersonationProxySpec) error { return fmt.Errorf("externalEndpoint must be set when service.type is None") } - if err := validateExternalEndpoint(spec.ExternalEndpoint); err != nil { - return fmt.Errorf("invalid ExternalEndpoint %q: %w", spec.ExternalEndpoint, err) + if spec.ExternalEndpoint != "" { + if _, err := endpointaddr.Parse(spec.ExternalEndpoint, 443); err != nil { + return fmt.Errorf("invalid ExternalEndpoint %q: %w", spec.ExternalEndpoint, err) + } } return nil } - -func validateExternalEndpoint(endpoint string) error { - // Empty string is valid (no external endpoint, default to service name) - if endpoint == "" { - return nil - } - - // Try parsing it both with and without an implicit port 443 at the end. - host, port, err := net.SplitHostPort(endpoint) - - // If we got an error parsing the raw input, try adding an implicit port 443. - if err != nil { - host, port, err = net.SplitHostPort(net.JoinHostPort(endpoint, "443")) - } - - // If there's still an error, fail now. - if err != nil { - return err - } - - portInt, _ := strconv.Atoi(port) - if len(validation.IsValidPortNum(portInt)) > 0 { - return fmt.Errorf("invalid port %q", port) - } - - // Check if the host part is a valid IP address. - if len(validation.IsValidIP(host)) == 0 { - return nil - } - - // Check if the host part is a valid hostname according to RFC 1123. - if len(validation.IsDNS1123Subdomain(host)) == 0 { - return nil - } - - return fmt.Errorf("host %q is not a valid hostname or IP address", host) -} diff --git a/internal/controller/impersonatorconfig/impersonator_config_test.go b/internal/controller/impersonatorconfig/impersonator_config_test.go index f188a4cf..ec4840a9 100644 --- a/internal/controller/impersonatorconfig/impersonator_config_test.go +++ b/internal/controller/impersonatorconfig/impersonator_config_test.go @@ -3577,53 +3577,3 @@ func (q *testQueue) AddRateLimited(key controllerlib.Key) { q.key = key } - -func TestValidateExternalEndpoint(t *testing.T) { - t.Parallel() - - for _, tt := range []struct { - input string - expectErr string - }{ - {input: ""}, - {input: "127.0.0.1"}, - {input: "127.0.0.1:8443"}, - {input: "[127.0.0.1]:8443"}, - {input: "2001:db8::ffff"}, - {input: "[2001:db8::ffff]:8443"}, - {input: "host.example.com"}, - {input: "host-dev.example.com"}, - {input: "host.example.com:8443"}, - {input: "[host.example.com]:8443"}, - { - input: "https://host.example.com", - expectErr: `invalid port "//host.example.com"`, - }, - { - input: "host.example.com/some/path", - expectErr: `host "host.example.com/some/path" is not a valid hostname or IP address`, - }, - { - input: "[host.example.com", - expectErr: "address [host.example.com:443: missing ']' in address", - }, - { - input: "___.example.com:1234", - expectErr: `host "___.example.com" is not a valid hostname or IP address`, - }, - { - input: "HOST.EXAMPLE.COM", - expectErr: `host "HOST.EXAMPLE.COM" is not a valid hostname or IP address`, - }, - } { - tt := tt - t.Run(fmt.Sprintf("parse %q", tt.input), func(t *testing.T) { - got := validateExternalEndpoint(tt.input) - if tt.expectErr == "" { - require.NoError(t, got) - } else { - require.EqualError(t, got, tt.expectErr) - } - }) - } -} diff --git a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go index 6eaab2f3..c801e741 100644 --- a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go @@ -26,6 +26,7 @@ import ( pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions" "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" @@ -871,9 +872,9 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { tt.setupMocks(conn) } - dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) { + dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (upstreamldap.Conn, error) { if tt.dialErrors != nil { - dialErr := tt.dialErrors[hostAndPort] + dialErr := tt.dialErrors[addr.Endpoint()] if dialErr != nil { return nil, dialErr } diff --git a/internal/endpointaddr/endpointaddr.go b/internal/endpointaddr/endpointaddr.go new file mode 100644 index 00000000..fd927765 --- /dev/null +++ b/internal/endpointaddr/endpointaddr.go @@ -0,0 +1,71 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package endpointaddr implements parsing and validation of "[:]" strings for Pinniped APIs. +package endpointaddr + +import ( + "fmt" + "net" + "strconv" + + "k8s.io/apimachinery/pkg/util/validation" +) + +type HostPort struct { + // Host is the validated host part of the input, which may be a hostname or IP. + // + // This string can be be used as an x509 certificate SAN. + Host string + + // Port is the validated port number, which may be defaulted. + Port uint16 +} + +// Endpoint is the host:port validated from the input, where port may be a default value. +// +// This string can be passed to net.Dial. +func (h *HostPort) Endpoint() string { + return net.JoinHostPort(h.Host, strconv.Itoa(int(h.Port))) +} + +// Parse an "endpoint address" string, providing a default port. The input can be in several valid formats: +// +// - "" (DNS hostname) +// - "" (IPv4 address) +// - "" (IPv6 address) +// - ":" (DNS hostname with port) +// - ":" (IPv4 address with port) +// - "[]:" (IPv6 address with port, brackets are required) +// +// If the input does not not specify a port number, then defaultPort will be used. +func Parse(endpoint string, defaultPort uint16) (HostPort, error) { + // Try parsing it both with and without an implicit port 443 at the end. + host, port, err := net.SplitHostPort(endpoint) + + // If we got an error parsing the raw input, try adding the default port. + if err != nil { + host, port, err = net.SplitHostPort(net.JoinHostPort(endpoint, strconv.Itoa(int(defaultPort)))) + } + + // Give up if there's still an error splitting the host and port. + if err != nil { + return HostPort{}, err + } + + // Parse the port number is an integer in the range of valid ports. + integerPort, _ := strconv.Atoi(port) + if len(validation.IsValidPortNum(integerPort)) > 0 { + return HostPort{}, fmt.Errorf("invalid port %q", port) + } + + // Check if the host part is a IPv4 or IPv6 address or a valid hostname according to RFC 1123. + switch { + case len(validation.IsValidIP(host)) == 0: + case len(validation.IsDNS1123Subdomain(host)) == 0: + default: + return HostPort{}, fmt.Errorf("host %q is not a valid hostname or IP address", host) + } + + return HostPort{Host: host, Port: uint16(integerPort)}, nil +} diff --git a/internal/endpointaddr/endpointaddr_test.go b/internal/endpointaddr/endpointaddr_test.go new file mode 100644 index 00000000..736df312 --- /dev/null +++ b/internal/endpointaddr/endpointaddr_test.go @@ -0,0 +1,182 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package endpointaddr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + t.Parallel() + for _, tt := range []struct { + name string + input string + defaultPort uint16 + expectErr string + expect HostPort + expectEndpoint string + }{ + { + name: "plain IPv4", + input: "127.0.0.1", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 443}, + expectEndpoint: "127.0.0.1:443", + }, + { + name: "IPv4 with port", + input: "127.0.0.1:8443", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 8443}, + expectEndpoint: "127.0.0.1:8443", + }, + { + name: "IPv4 in brackets with port", + input: "[127.0.0.1]:8443", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 8443}, + expectEndpoint: "127.0.0.1:8443", + }, + { + name: "IPv4 as IPv6 in brackets with port", + input: "[::127.0.0.1]:8443", + defaultPort: 443, + expect: HostPort{Host: "::127.0.0.1", Port: 8443}, + expectEndpoint: "[::127.0.0.1]:8443", + }, + { + name: "IPv4 as IPv6 without port", + input: "::127.0.0.1", + defaultPort: 443, + expect: HostPort{Host: "::127.0.0.1", Port: 443}, + expectEndpoint: "[::127.0.0.1]:443", + }, + { + name: "plain IPv6 without port", + input: "2001:db8::ffff", + defaultPort: 443, + expect: HostPort{Host: "2001:db8::ffff", Port: 443}, + expectEndpoint: "[2001:db8::ffff]:443", + }, + { + name: "IPv6 with port", + input: "[2001:db8::ffff]:8443", + defaultPort: 443, + expect: HostPort{Host: "2001:db8::ffff", Port: 8443}, + expectEndpoint: "[2001:db8::ffff]:8443", + }, + { + name: "plain hostname", + input: "host.example.com", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 443}, + expectEndpoint: "host.example.com:443", + }, + { + name: "plain hostname with dash", + input: "host-dev.example.com", + defaultPort: 443, + expect: HostPort{Host: "host-dev.example.com", Port: 443}, + expectEndpoint: "host-dev.example.com:443", + }, + { + name: "hostname with port", + input: "host.example.com:8443", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 8443}, + expectEndpoint: "host.example.com:8443", + }, + { + name: "hostname in brackets with port", + input: "[host.example.com]:8443", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 8443}, + expectEndpoint: "host.example.com:8443", + }, + { + name: "hostname without dots", + input: "localhost", + defaultPort: 443, + expect: HostPort{Host: "localhost", Port: 443}, + expectEndpoint: "localhost:443", + }, + { + name: "hostname and port without dots", + input: "localhost:8443", + defaultPort: 443, + expect: HostPort{Host: "localhost", Port: 8443}, + expectEndpoint: "localhost:8443", + }, + { + name: "invalid empty string", + input: "", + defaultPort: 443, + expectErr: `host "" is not a valid hostname or IP address`, + }, + { + // IPv6 zone index specifiers are not yet supported. + name: "IPv6 with port and zone index", + input: "[2001:db8::ffff%lo0]:8443", + defaultPort: 443, + expectErr: `host "2001:db8::ffff%lo0" is not a valid hostname or IP address`, + }, + { + name: "IPv6 in brackets without port", + input: "[2001:db8::ffff]", + defaultPort: 443, + expectErr: `address [[2001:db8::ffff]]:443: missing port in address`, + }, + { + name: "invalid HTTPS URL", + input: "https://host.example.com", + defaultPort: 443, + expectErr: `invalid port "//host.example.com"`, + }, + { + name: "invalid host with URL path", + input: "host.example.com/some/path", + defaultPort: 443, + expectErr: `host "host.example.com/some/path" is not a valid hostname or IP address`, + }, + { + name: "invalid host with mismatched brackets", + input: "[host.example.com", + defaultPort: 443, + expectErr: "address [host.example.com:443: missing ']' in address", + }, + { + name: "invalid host with underscores", + input: "___.example.com:1234", + defaultPort: 443, + expectErr: `host "___.example.com" is not a valid hostname or IP address`, + }, + { + name: "invalid host with uppercase", + input: "HOST.EXAMPLE.COM", + defaultPort: 443, + expectErr: `host "HOST.EXAMPLE.COM" is not a valid hostname or IP address`, + }, + { + name: "invalid host with extra port", + input: "host.example.com:port1:port2", + defaultPort: 443, + expectErr: `host "host.example.com:port1:port2" is not a valid hostname or IP address`, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := Parse(tt.input, tt.defaultPort) + if tt.expectErr == "" { + assert.NoError(t, err) + assert.Equal(t, tt.expect, got) + assert.Equal(t, tt.expectEndpoint, got.Endpoint()) + } else { + assert.EqualError(t, err, tt.expectErr) + assert.Equal(t, HostPort{}, got) + } + }) + } +} diff --git a/internal/testutil/tlsserver.go b/internal/testutil/tlsserver.go index c6d1a522..13e39324 100644 --- a/internal/testutil/tlsserver.go +++ b/internal/testutil/tlsserver.go @@ -1,13 +1,18 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testutil import ( + "crypto/tls" "encoding/pem" + "errors" + "net" "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/require" ) // TLSTestServer starts a test server listening on a local port using a test CA. It returns the PEM CA bundle and the @@ -23,3 +28,35 @@ func TLSTestServer(t *testing.T, handler http.HandlerFunc) (caBundlePEM string, })) return caBundle, server.URL } + +func TLSTestServerWithCert(t *testing.T, handler http.HandlerFunc, certificate *tls.Certificate) (url string) { + t.Helper() + + server := http.Server{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{*certificate}, + MinVersion: tls.VersionTLS12, + }, + Handler: handler, + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + serverShutdownChan := make(chan error) + go func() { + // Empty certFile and keyFile will use certs from Server.TLSConfig. + serverShutdownChan <- server.ServeTLS(l, "", "") + }() + + t.Cleanup(func() { + _ = server.Close() + serveErr := <-serverShutdownChan + if !errors.Is(serveErr, http.ErrServerClosed) { + t.Log("Got an unexpected error while starting the fake http server!") + require.NoError(t, serveErr) + } + }) + + return l.Addr().String() +} diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 54c51293..f857c730 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -21,6 +21,7 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/plog" ) @@ -30,6 +31,8 @@ const ( commonNameAttributeName = "cn" searchFilterInterpolationLocationMarker = "{}" groupSearchPageSize = uint32(250) + defaultLDAPPort = uint16(389) + defaultLDAPSPort = uint16(636) ) // Conn abstracts the upstream LDAP communication protocol (mostly for testing). @@ -48,16 +51,16 @@ var _ Conn = &ldap.Conn{} // LDAPDialer is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP. type LDAPDialer interface { - Dial(ctx context.Context, hostAndPort string) (Conn, error) + Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) } // LDAPDialerFunc makes it easy to use a func as an LDAPDialer. -type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error) +type LDAPDialerFunc func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) var _ LDAPDialer = LDAPDialerFunc(nil) -func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, error) { - return f(ctx, hostAndPort) +func (f LDAPDialerFunc) Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { + return f(ctx, addr) } type LDAPConnectionProtocol string @@ -147,26 +150,26 @@ func (p *Provider) GetConfig() ProviderConfig { } func (p *Provider) dial(ctx context.Context) (Conn, error) { - tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) + tlsAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPSPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort) + startTLSAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } // Choose how and where to dial based on TLS vs. StartTLS config option. var dialFunc LDAPDialerFunc - var hostAndPort string + var addr endpointaddr.HostPort switch { case p.c.ConnectionProtocol == TLS: dialFunc = p.dialTLS - hostAndPort = tlsHostAndPort + addr = tlsAddr case p.c.ConnectionProtocol == StartTLS: dialFunc = p.dialStartTLS - hostAndPort = startTLSHostAndPort + addr = startTLSAddr default: return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol")) } @@ -176,20 +179,20 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) { dialFunc = p.c.Dialer.Dial } - return dialFunc(ctx, hostAndPort) + return dialFunc(ctx, addr) } // dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. -func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { +func (p *Provider) dialTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { tlsConfig, err := p.tlsConfig() if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig} - c, err := dialer.DialContext(ctx, "tcp", hostAndPort) + c, err := dialer.DialContext(ctx, "tcp", addr.Endpoint()) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } @@ -202,20 +205,16 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error // dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. -func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) { +func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { tlsConfig, err := p.tlsConfig() if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - host, err := hostWithoutPort(hostAndPort) - if err != nil { - return nil, ldap.NewError(ldap.ErrorNetwork, err) - } // Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS. - tlsConfig.ServerName = host + tlsConfig.ServerName = addr.Host - c, err := netDialer().DialContext(ctx, "tcp", hostAndPort) + c, err := netDialer().DialContext(ctx, "tcp", addr.Endpoint()) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } @@ -245,44 +244,6 @@ func (p *Provider) tlsConfig() (*tls.Config, error) { return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil } -// Adds the default port if hostAndPort did not already include a port. -func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) { - host, port, err := net.SplitHostPort(hostAndPort) - if err != nil { - if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare - host = hostAndPort - port = defaultPort - } else { - return "", err // hostAndPort argument was not parsable - } - } - switch { - case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"): - // don't add extra square brackets to an IPv6 address that already has them - return fmt.Sprintf("%s:%s", host, port), nil - case port != "": - return net.JoinHostPort(host, port), nil - default: - return host, nil - } -} - -// Strip the port from a host or host:port. -func hostWithoutPort(hostAndPort string) (string, error) { - host, _, err := net.SplitHostPort(hostAndPort) - if err != nil { - if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare - return hostAndPort, nil - } - return "", err // hostAndPort argument was not parsable - } - if strings.HasPrefix(hostAndPort, "[") { - // it was an IPv6 address, so preserve the square brackets. - return fmt.Sprintf("[%s]", host), nil - } - return host, nil -} - // A name for this upstream provider. func (p *Provider) GetName() string { return p.c.Name diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index f66d69ce..fd8b9658 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "testing" + "time" "github.com/go-ldap/ldap/v3" "github.com/golang/mock/gomock" @@ -19,6 +20,8 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" + "go.pinniped.dev/internal/certauthority" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/testutil" ) @@ -924,9 +927,9 @@ func TestEndUserAuthentication(t *testing.T) { } dialWasAttempted := false - tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { dialWasAttempted = true - require.Equal(t, tt.providerConfig.Host, hostAndPort) + require.Equal(t, tt.providerConfig.Host, addr.Endpoint()) if tt.dialError != nil { return nil, tt.dialError } @@ -1059,9 +1062,9 @@ func TestTestConnection(t *testing.T) { } dialWasAttempted := false - tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { dialWasAttempted = true - require.Equal(t, tt.providerConfig.Host, hostAndPort) + require.Equal(t, tt.providerConfig.Host, addr.Endpoint()) if tt.dialError != nil { return nil, tt.dialError } @@ -1123,6 +1126,13 @@ func TestRealTLSDialing(t *testing.T) { require.NoError(t, err) testServerHostAndPort := parsedURL.Host + caForTestServerWithBadCertName, err := certauthority.New("Test CA", time.Hour) + require.NoError(t, err) + wrongIP := net.ParseIP("10.2.3.4") + cert, err := caForTestServerWithBadCertName.IssueServerCert([]string{"wrong-dns-name"}, []net.IP{wrongIP}, time.Hour) + require.NoError(t, err) + testServerWithBadCertNameAddr := testutil.TLSTestServerWithCert(t, func(w http.ResponseWriter, r *http.Request) {}, cert) + unusedPortGrabbingListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) recentlyClaimedHostAndPort := unusedPortGrabbingListener.Addr().String() @@ -1146,6 +1156,14 @@ func TestRealTLSDialing(t *testing.T) { connProto: TLS, context: context.Background(), }, + { + name: "server cert name does not match the address to which the client connected", + host: testServerWithBadCertNameAddr, + caBundle: caForTestServerWithBadCertName.Bundle(), + connProto: TLS, + context: context.Background(), + wantError: `LDAP Result Code 200 "Network Error": x509: certificate is valid for 10.2.3.4, not 127.0.0.1`, + }, { name: "invalid CA bundle with TLS", host: testServerHostAndPort, @@ -1168,7 +1186,7 @@ func TestRealTLSDialing(t *testing.T) { caBundle: []byte(testServerCABundle), connProto: TLS, context: context.Background(), - wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, }, { name: "invalid host with StartTLS", @@ -1176,7 +1194,7 @@ func TestRealTLSDialing(t *testing.T) { caBundle: []byte(testServerCABundle), connProto: StartTLS, context: context.Background(), - wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, }, { name: "missing CA bundle when it is required because the host is not using a trusted CA", @@ -1244,123 +1262,3 @@ func TestRealTLSDialing(t *testing.T) { }) } } - -// Test various cases of host and port parsing. -func TestHostAndPortWithDefaultPort(t *testing.T) { - tests := []struct { - name string - hostAndPort string - defaultPort string - wantError string - wantHostAndPort string - }{ - { - name: "host already has port", - hostAndPort: "host.example.com:99", - defaultPort: "42", - wantHostAndPort: "host.example.com:99", - }, - { - name: "host does not have port", - hostAndPort: "host.example.com", - defaultPort: "42", - wantHostAndPort: "host.example.com:42", - }, - { - name: "host does not have port and default port is empty", - hostAndPort: "host.example.com", - defaultPort: "", - wantHostAndPort: "host.example.com", - }, - { - name: "host has port and default port is empty", - hostAndPort: "host.example.com:42", - defaultPort: "", - wantHostAndPort: "host.example.com:42", - }, - { - name: "IPv6 host already has port", - hostAndPort: "[::1%lo0]:80", - defaultPort: "42", - wantHostAndPort: "[::1%lo0]:80", - }, - { - name: "IPv6 host does not have port", - hostAndPort: "[::1%lo0]", - defaultPort: "42", - wantHostAndPort: "[::1%lo0]:42", - }, - { - name: "IPv6 host does not have port and default port is empty", - hostAndPort: "[::1%lo0]", - defaultPort: "", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "host is not valid", - hostAndPort: "host.example.com:port1:port2", - defaultPort: "42", - wantError: "address host.example.com:port1:port2: too many colons in address", - }, - } - for _, test := range tests { - tt := test - t.Run(tt.name, func(t *testing.T) { - hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort) - if tt.wantError != "" { - require.EqualError(t, err, tt.wantError) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantHostAndPort, hostAndPort) - }) - } -} - -// Test various cases of host and port parsing. -func TestHostWithoutPort(t *testing.T) { - tests := []struct { - name string - hostAndPort string - wantError string - wantHostAndPort string - }{ - { - name: "host already has port", - hostAndPort: "host.example.com:99", - wantHostAndPort: "host.example.com", - }, - { - name: "host does not have port", - hostAndPort: "host.example.com", - wantHostAndPort: "host.example.com", - }, - { - name: "IPv6 host already has port", - hostAndPort: "[::1%lo0]:80", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "IPv6 host does not have port", - hostAndPort: "[::1%lo0]", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "host is not valid", - hostAndPort: "host.example.com:port1:port2", - wantError: "address host.example.com:port1:port2: too many colons in address", - }, - } - for _, test := range tests { - tt := test - t.Run(tt.name, func(t *testing.T) { - hostAndPort, err := hostWithoutPort(tt.hostAndPort) - if tt.wantError != "" { - require.EqualError(t, err, tt.wantError) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantHostAndPort, hostAndPort) - }) - } -} diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index d62b51f4..765b1f45 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -344,7 +344,7 @@ func TestLDAPSearch(t *testing.T) { username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), - wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, + wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, }, { name: "when the server is not parsable with StartTLS", @@ -355,7 +355,7 @@ func TestLDAPSearch(t *testing.T) { p.ConnectionProtocol = upstreamldap.StartTLS p.Host = "too:many:ports" })), - wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, + wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, }, { name: "when the CA bundle is not parsable with TLS",