From b16e84d90a69330e9627e2974c2904112f25b972 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Fri, 21 May 2021 12:44:01 -0700 Subject: [PATCH 1/5] Add another unit test for the LDAP client code --- internal/testutil/tlsserver.go | 37 +++++++++++++++++++++- internal/upstreamldap/upstreamldap_test.go | 17 ++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/internal/testutil/tlsserver.go b/internal/testutil/tlsserver.go index c6d1a522..c55e07de 100644 --- a/internal/testutil/tlsserver.go +++ b/internal/testutil/tlsserver.go @@ -1,13 +1,18 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testutil import ( + "crypto/tls" "encoding/pem" + "errors" + "net" "net/http" "net/http/httptest" "testing" + + "github.com/stretchr/testify/require" ) // TLSTestServer starts a test server listening on a local port using a test CA. It returns the PEM CA bundle and the @@ -23,3 +28,33 @@ func TLSTestServer(t *testing.T, handler http.HandlerFunc) (caBundlePEM string, })) return caBundle, server.URL } + +func TLSTestServerWithCert(t *testing.T, handler http.HandlerFunc, certificate *tls.Certificate) (url string) { + t.Helper() + + server := http.Server{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{*certificate}, + MinVersion: tls.VersionTLS12, + }, + Handler: handler, + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + go func() { + // Empty certFile and keyFile will use certs from Server.TLSConfig. + serveErr := server.ServeTLS(l, "", "") + if !errors.Is(serveErr, http.ErrServerClosed) { + t.Log("Got an unexpected error while starting the fake http server!") + require.NoError(t, serveErr) + } + }() + + t.Cleanup(func() { + _ = server.Close() + }) + + return l.Addr().String() +} diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index f66d69ce..9feb4b64 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "testing" + "time" "github.com/go-ldap/ldap/v3" "github.com/golang/mock/gomock" @@ -19,6 +20,7 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" + "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/testutil" ) @@ -1123,6 +1125,13 @@ func TestRealTLSDialing(t *testing.T) { require.NoError(t, err) testServerHostAndPort := parsedURL.Host + caForTestServerWithBadCertName, err := certauthority.New("Test CA", time.Hour) + require.NoError(t, err) + wrongIP := net.ParseIP("10.2.3.4") + cert, err := caForTestServerWithBadCertName.IssueServerCert([]string{"wrong-dns-name"}, []net.IP{wrongIP}, time.Hour) + require.NoError(t, err) + testServerWithBadCertNameAddr := testutil.TLSTestServerWithCert(t, func(w http.ResponseWriter, r *http.Request) {}, cert) + unusedPortGrabbingListener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) recentlyClaimedHostAndPort := unusedPortGrabbingListener.Addr().String() @@ -1146,6 +1155,14 @@ func TestRealTLSDialing(t *testing.T) { connProto: TLS, context: context.Background(), }, + { + name: "server cert name does not match the address to which the client connected", + host: testServerWithBadCertNameAddr, + caBundle: caForTestServerWithBadCertName.Bundle(), + connProto: TLS, + context: context.Background(), + wantError: `LDAP Result Code 200 "Network Error": x509: certificate is valid for 10.2.3.4, not 127.0.0.1`, + }, { name: "invalid CA bundle with TLS", host: testServerHostAndPort, From 2014f4623dee3f0a22ca7615a5f7638f04e239d7 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Mon, 24 May 2021 14:24:09 -0700 Subject: [PATCH 2/5] Move require.NoError() to t.Cleanup() --- internal/testutil/tlsserver.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/internal/testutil/tlsserver.go b/internal/testutil/tlsserver.go index c55e07de..13e39324 100644 --- a/internal/testutil/tlsserver.go +++ b/internal/testutil/tlsserver.go @@ -43,17 +43,19 @@ func TLSTestServerWithCert(t *testing.T, handler http.HandlerFunc, certificate * l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) + serverShutdownChan := make(chan error) go func() { // Empty certFile and keyFile will use certs from Server.TLSConfig. - serveErr := server.ServeTLS(l, "", "") - if !errors.Is(serveErr, http.ErrServerClosed) { - t.Log("Got an unexpected error while starting the fake http server!") - require.NoError(t, serveErr) - } + serverShutdownChan <- server.ServeTLS(l, "", "") }() t.Cleanup(func() { _ = server.Close() + serveErr := <-serverShutdownChan + if !errors.Is(serveErr, http.ErrServerClosed) { + t.Log("Got an unexpected error while starting the fake http server!") + require.NoError(t, serveErr) + } }) return l.Addr().String() From f89f2281d86703a25408cb89af2d817232dc2915 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 25 May 2021 05:51:17 +0000 Subject: [PATCH 3/5] Bump github.com/google/go-cmp from 0.5.5 to 0.5.6 Bumps [github.com/google/go-cmp](https://github.com/google/go-cmp) from 0.5.5 to 0.5.6. - [Release notes](https://github.com/google/go-cmp/releases) - [Commits](https://github.com/google/go-cmp/compare/v0.5.5...v0.5.6) Signed-off-by: dependabot[bot] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 8004534f..82b6a2a0 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/go-openapi/spec v0.20.3 // indirect github.com/gofrs/flock v0.8.0 github.com/golang/mock v1.5.0 - github.com/google/go-cmp v0.5.5 + github.com/google/go-cmp v0.5.6 github.com/google/gofuzz v1.2.0 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/websocket v1.4.2 diff --git a/go.sum b/go.sum index 30a8df74..28caead8 100644 --- a/go.sum +++ b/go.sum @@ -514,8 +514,8 @@ github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMyw github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-jsonnet v0.16.0/go.mod h1:sOcuej3UW1vpPTZOr8L7RQimqai1a57bt5j22LzGZCw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= From d9a3992b3b6445ea37ebd693b40405b719079f18 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Tue, 25 May 2021 14:32:57 -0500 Subject: [PATCH 4/5] Add endpointaddr pkg for parsing host+port inputs. This type of field appears in more than one of our APIs, so this package will provide a single source of truth for validating and parsing inputs. Signed-off-by: Matt Moyer --- internal/endpointaddr/endpointaddr.go | 71 ++++++++ internal/endpointaddr/endpointaddr_test.go | 182 +++++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 internal/endpointaddr/endpointaddr.go create mode 100644 internal/endpointaddr/endpointaddr_test.go diff --git a/internal/endpointaddr/endpointaddr.go b/internal/endpointaddr/endpointaddr.go new file mode 100644 index 00000000..fd927765 --- /dev/null +++ b/internal/endpointaddr/endpointaddr.go @@ -0,0 +1,71 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package endpointaddr implements parsing and validation of "[:]" strings for Pinniped APIs. +package endpointaddr + +import ( + "fmt" + "net" + "strconv" + + "k8s.io/apimachinery/pkg/util/validation" +) + +type HostPort struct { + // Host is the validated host part of the input, which may be a hostname or IP. + // + // This string can be be used as an x509 certificate SAN. + Host string + + // Port is the validated port number, which may be defaulted. + Port uint16 +} + +// Endpoint is the host:port validated from the input, where port may be a default value. +// +// This string can be passed to net.Dial. +func (h *HostPort) Endpoint() string { + return net.JoinHostPort(h.Host, strconv.Itoa(int(h.Port))) +} + +// Parse an "endpoint address" string, providing a default port. The input can be in several valid formats: +// +// - "" (DNS hostname) +// - "" (IPv4 address) +// - "" (IPv6 address) +// - ":" (DNS hostname with port) +// - ":" (IPv4 address with port) +// - "[]:" (IPv6 address with port, brackets are required) +// +// If the input does not not specify a port number, then defaultPort will be used. +func Parse(endpoint string, defaultPort uint16) (HostPort, error) { + // Try parsing it both with and without an implicit port 443 at the end. + host, port, err := net.SplitHostPort(endpoint) + + // If we got an error parsing the raw input, try adding the default port. + if err != nil { + host, port, err = net.SplitHostPort(net.JoinHostPort(endpoint, strconv.Itoa(int(defaultPort)))) + } + + // Give up if there's still an error splitting the host and port. + if err != nil { + return HostPort{}, err + } + + // Parse the port number is an integer in the range of valid ports. + integerPort, _ := strconv.Atoi(port) + if len(validation.IsValidPortNum(integerPort)) > 0 { + return HostPort{}, fmt.Errorf("invalid port %q", port) + } + + // Check if the host part is a IPv4 or IPv6 address or a valid hostname according to RFC 1123. + switch { + case len(validation.IsValidIP(host)) == 0: + case len(validation.IsDNS1123Subdomain(host)) == 0: + default: + return HostPort{}, fmt.Errorf("host %q is not a valid hostname or IP address", host) + } + + return HostPort{Host: host, Port: uint16(integerPort)}, nil +} diff --git a/internal/endpointaddr/endpointaddr_test.go b/internal/endpointaddr/endpointaddr_test.go new file mode 100644 index 00000000..736df312 --- /dev/null +++ b/internal/endpointaddr/endpointaddr_test.go @@ -0,0 +1,182 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package endpointaddr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + t.Parallel() + for _, tt := range []struct { + name string + input string + defaultPort uint16 + expectErr string + expect HostPort + expectEndpoint string + }{ + { + name: "plain IPv4", + input: "127.0.0.1", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 443}, + expectEndpoint: "127.0.0.1:443", + }, + { + name: "IPv4 with port", + input: "127.0.0.1:8443", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 8443}, + expectEndpoint: "127.0.0.1:8443", + }, + { + name: "IPv4 in brackets with port", + input: "[127.0.0.1]:8443", + defaultPort: 443, + expect: HostPort{Host: "127.0.0.1", Port: 8443}, + expectEndpoint: "127.0.0.1:8443", + }, + { + name: "IPv4 as IPv6 in brackets with port", + input: "[::127.0.0.1]:8443", + defaultPort: 443, + expect: HostPort{Host: "::127.0.0.1", Port: 8443}, + expectEndpoint: "[::127.0.0.1]:8443", + }, + { + name: "IPv4 as IPv6 without port", + input: "::127.0.0.1", + defaultPort: 443, + expect: HostPort{Host: "::127.0.0.1", Port: 443}, + expectEndpoint: "[::127.0.0.1]:443", + }, + { + name: "plain IPv6 without port", + input: "2001:db8::ffff", + defaultPort: 443, + expect: HostPort{Host: "2001:db8::ffff", Port: 443}, + expectEndpoint: "[2001:db8::ffff]:443", + }, + { + name: "IPv6 with port", + input: "[2001:db8::ffff]:8443", + defaultPort: 443, + expect: HostPort{Host: "2001:db8::ffff", Port: 8443}, + expectEndpoint: "[2001:db8::ffff]:8443", + }, + { + name: "plain hostname", + input: "host.example.com", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 443}, + expectEndpoint: "host.example.com:443", + }, + { + name: "plain hostname with dash", + input: "host-dev.example.com", + defaultPort: 443, + expect: HostPort{Host: "host-dev.example.com", Port: 443}, + expectEndpoint: "host-dev.example.com:443", + }, + { + name: "hostname with port", + input: "host.example.com:8443", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 8443}, + expectEndpoint: "host.example.com:8443", + }, + { + name: "hostname in brackets with port", + input: "[host.example.com]:8443", + defaultPort: 443, + expect: HostPort{Host: "host.example.com", Port: 8443}, + expectEndpoint: "host.example.com:8443", + }, + { + name: "hostname without dots", + input: "localhost", + defaultPort: 443, + expect: HostPort{Host: "localhost", Port: 443}, + expectEndpoint: "localhost:443", + }, + { + name: "hostname and port without dots", + input: "localhost:8443", + defaultPort: 443, + expect: HostPort{Host: "localhost", Port: 8443}, + expectEndpoint: "localhost:8443", + }, + { + name: "invalid empty string", + input: "", + defaultPort: 443, + expectErr: `host "" is not a valid hostname or IP address`, + }, + { + // IPv6 zone index specifiers are not yet supported. + name: "IPv6 with port and zone index", + input: "[2001:db8::ffff%lo0]:8443", + defaultPort: 443, + expectErr: `host "2001:db8::ffff%lo0" is not a valid hostname or IP address`, + }, + { + name: "IPv6 in brackets without port", + input: "[2001:db8::ffff]", + defaultPort: 443, + expectErr: `address [[2001:db8::ffff]]:443: missing port in address`, + }, + { + name: "invalid HTTPS URL", + input: "https://host.example.com", + defaultPort: 443, + expectErr: `invalid port "//host.example.com"`, + }, + { + name: "invalid host with URL path", + input: "host.example.com/some/path", + defaultPort: 443, + expectErr: `host "host.example.com/some/path" is not a valid hostname or IP address`, + }, + { + name: "invalid host with mismatched brackets", + input: "[host.example.com", + defaultPort: 443, + expectErr: "address [host.example.com:443: missing ']' in address", + }, + { + name: "invalid host with underscores", + input: "___.example.com:1234", + defaultPort: 443, + expectErr: `host "___.example.com" is not a valid hostname or IP address`, + }, + { + name: "invalid host with uppercase", + input: "HOST.EXAMPLE.COM", + defaultPort: 443, + expectErr: `host "HOST.EXAMPLE.COM" is not a valid hostname or IP address`, + }, + { + name: "invalid host with extra port", + input: "host.example.com:port1:port2", + defaultPort: 443, + expectErr: `host "host.example.com:port1:port2" is not a valid hostname or IP address`, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := Parse(tt.input, tt.defaultPort) + if tt.expectErr == "" { + assert.NoError(t, err) + assert.Equal(t, tt.expect, got) + assert.Equal(t, tt.expectEndpoint, got.Endpoint()) + } else { + assert.EqualError(t, err, tt.expectErr) + assert.Equal(t, HostPort{}, got) + } + }) + } +} From 89eff285499a6190eed18cd9626e1acc2112270f Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Tue, 25 May 2021 14:46:50 -0500 Subject: [PATCH 5/5] Convert LDAP code to use endpointaddr package. Signed-off-by: Matt Moyer --- .../ldap_upstream_watcher_test.go | 5 +- internal/upstreamldap/upstreamldap.go | 75 +++------- internal/upstreamldap/upstreamldap_test.go | 133 +----------------- test/integration/ldap_client_test.go | 4 +- 4 files changed, 30 insertions(+), 187 deletions(-) diff --git a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go index 6eaab2f3..c801e741 100644 --- a/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go +++ b/internal/controller/supervisorconfig/ldapupstreamwatcher/ldap_upstream_watcher_test.go @@ -26,6 +26,7 @@ import ( pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions" "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" @@ -871,9 +872,9 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { tt.setupMocks(conn) } - dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) { + dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (upstreamldap.Conn, error) { if tt.dialErrors != nil { - dialErr := tt.dialErrors[hostAndPort] + dialErr := tt.dialErrors[addr.Endpoint()] if dialErr != nil { return nil, dialErr } diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 54c51293..f857c730 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -21,6 +21,7 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/user" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/plog" ) @@ -30,6 +31,8 @@ const ( commonNameAttributeName = "cn" searchFilterInterpolationLocationMarker = "{}" groupSearchPageSize = uint32(250) + defaultLDAPPort = uint16(389) + defaultLDAPSPort = uint16(636) ) // Conn abstracts the upstream LDAP communication protocol (mostly for testing). @@ -48,16 +51,16 @@ var _ Conn = &ldap.Conn{} // LDAPDialer is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP. type LDAPDialer interface { - Dial(ctx context.Context, hostAndPort string) (Conn, error) + Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) } // LDAPDialerFunc makes it easy to use a func as an LDAPDialer. -type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error) +type LDAPDialerFunc func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) var _ LDAPDialer = LDAPDialerFunc(nil) -func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, error) { - return f(ctx, hostAndPort) +func (f LDAPDialerFunc) Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { + return f(ctx, addr) } type LDAPConnectionProtocol string @@ -147,26 +150,26 @@ func (p *Provider) GetConfig() ProviderConfig { } func (p *Provider) dial(ctx context.Context) (Conn, error) { - tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort) + tlsAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPSPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort) + startTLSAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPPort) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } // Choose how and where to dial based on TLS vs. StartTLS config option. var dialFunc LDAPDialerFunc - var hostAndPort string + var addr endpointaddr.HostPort switch { case p.c.ConnectionProtocol == TLS: dialFunc = p.dialTLS - hostAndPort = tlsHostAndPort + addr = tlsAddr case p.c.ConnectionProtocol == StartTLS: dialFunc = p.dialStartTLS - hostAndPort = startTLSHostAndPort + addr = startTLSAddr default: return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol")) } @@ -176,20 +179,20 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) { dialFunc = p.c.Dialer.Dial } - return dialFunc(ctx, hostAndPort) + return dialFunc(ctx, addr) } // dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. -func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { +func (p *Provider) dialTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { tlsConfig, err := p.tlsConfig() if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig} - c, err := dialer.DialContext(ctx, "tcp", hostAndPort) + c, err := dialer.DialContext(ctx, "tcp", addr.Endpoint()) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } @@ -202,20 +205,16 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error // dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS. // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // so we implement it ourselves, heavily inspired by ldap.DialURL. -func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) { +func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { tlsConfig, err := p.tlsConfig() if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } - host, err := hostWithoutPort(hostAndPort) - if err != nil { - return nil, ldap.NewError(ldap.ErrorNetwork, err) - } // Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS. - tlsConfig.ServerName = host + tlsConfig.ServerName = addr.Host - c, err := netDialer().DialContext(ctx, "tcp", hostAndPort) + c, err := netDialer().DialContext(ctx, "tcp", addr.Endpoint()) if err != nil { return nil, ldap.NewError(ldap.ErrorNetwork, err) } @@ -245,44 +244,6 @@ func (p *Provider) tlsConfig() (*tls.Config, error) { return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil } -// Adds the default port if hostAndPort did not already include a port. -func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) { - host, port, err := net.SplitHostPort(hostAndPort) - if err != nil { - if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare - host = hostAndPort - port = defaultPort - } else { - return "", err // hostAndPort argument was not parsable - } - } - switch { - case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"): - // don't add extra square brackets to an IPv6 address that already has them - return fmt.Sprintf("%s:%s", host, port), nil - case port != "": - return net.JoinHostPort(host, port), nil - default: - return host, nil - } -} - -// Strip the port from a host or host:port. -func hostWithoutPort(hostAndPort string) (string, error) { - host, _, err := net.SplitHostPort(hostAndPort) - if err != nil { - if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare - return hostAndPort, nil - } - return "", err // hostAndPort argument was not parsable - } - if strings.HasPrefix(hostAndPort, "[") { - // it was an IPv6 address, so preserve the square brackets. - return fmt.Sprintf("[%s]", host), nil - } - return host, nil -} - // A name for this upstream provider. func (p *Provider) GetName() string { return p.c.Name diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 9feb4b64..fd8b9658 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -21,6 +21,7 @@ import ( "k8s.io/apiserver/pkg/authentication/user" "go.pinniped.dev/internal/certauthority" + "go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/testutil" ) @@ -926,9 +927,9 @@ func TestEndUserAuthentication(t *testing.T) { } dialWasAttempted := false - tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { dialWasAttempted = true - require.Equal(t, tt.providerConfig.Host, hostAndPort) + require.Equal(t, tt.providerConfig.Host, addr.Endpoint()) if tt.dialError != nil { return nil, tt.dialError } @@ -1061,9 +1062,9 @@ func TestTestConnection(t *testing.T) { } dialWasAttempted := false - tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { dialWasAttempted = true - require.Equal(t, tt.providerConfig.Host, hostAndPort) + require.Equal(t, tt.providerConfig.Host, addr.Endpoint()) if tt.dialError != nil { return nil, tt.dialError } @@ -1185,7 +1186,7 @@ func TestRealTLSDialing(t *testing.T) { caBundle: []byte(testServerCABundle), connProto: TLS, context: context.Background(), - wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, }, { name: "invalid host with StartTLS", @@ -1193,7 +1194,7 @@ func TestRealTLSDialing(t *testing.T) { caBundle: []byte(testServerCABundle), connProto: StartTLS, context: context.Background(), - wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`, + wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, }, { name: "missing CA bundle when it is required because the host is not using a trusted CA", @@ -1261,123 +1262,3 @@ func TestRealTLSDialing(t *testing.T) { }) } } - -// Test various cases of host and port parsing. -func TestHostAndPortWithDefaultPort(t *testing.T) { - tests := []struct { - name string - hostAndPort string - defaultPort string - wantError string - wantHostAndPort string - }{ - { - name: "host already has port", - hostAndPort: "host.example.com:99", - defaultPort: "42", - wantHostAndPort: "host.example.com:99", - }, - { - name: "host does not have port", - hostAndPort: "host.example.com", - defaultPort: "42", - wantHostAndPort: "host.example.com:42", - }, - { - name: "host does not have port and default port is empty", - hostAndPort: "host.example.com", - defaultPort: "", - wantHostAndPort: "host.example.com", - }, - { - name: "host has port and default port is empty", - hostAndPort: "host.example.com:42", - defaultPort: "", - wantHostAndPort: "host.example.com:42", - }, - { - name: "IPv6 host already has port", - hostAndPort: "[::1%lo0]:80", - defaultPort: "42", - wantHostAndPort: "[::1%lo0]:80", - }, - { - name: "IPv6 host does not have port", - hostAndPort: "[::1%lo0]", - defaultPort: "42", - wantHostAndPort: "[::1%lo0]:42", - }, - { - name: "IPv6 host does not have port and default port is empty", - hostAndPort: "[::1%lo0]", - defaultPort: "", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "host is not valid", - hostAndPort: "host.example.com:port1:port2", - defaultPort: "42", - wantError: "address host.example.com:port1:port2: too many colons in address", - }, - } - for _, test := range tests { - tt := test - t.Run(tt.name, func(t *testing.T) { - hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort) - if tt.wantError != "" { - require.EqualError(t, err, tt.wantError) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantHostAndPort, hostAndPort) - }) - } -} - -// Test various cases of host and port parsing. -func TestHostWithoutPort(t *testing.T) { - tests := []struct { - name string - hostAndPort string - wantError string - wantHostAndPort string - }{ - { - name: "host already has port", - hostAndPort: "host.example.com:99", - wantHostAndPort: "host.example.com", - }, - { - name: "host does not have port", - hostAndPort: "host.example.com", - wantHostAndPort: "host.example.com", - }, - { - name: "IPv6 host already has port", - hostAndPort: "[::1%lo0]:80", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "IPv6 host does not have port", - hostAndPort: "[::1%lo0]", - wantHostAndPort: "[::1%lo0]", - }, - { - name: "host is not valid", - hostAndPort: "host.example.com:port1:port2", - wantError: "address host.example.com:port1:port2: too many colons in address", - }, - } - for _, test := range tests { - tt := test - t.Run(tt.name, func(t *testing.T) { - hostAndPort, err := hostWithoutPort(tt.hostAndPort) - if tt.wantError != "" { - require.EqualError(t, err, tt.wantError) - } else { - require.NoError(t, err) - } - require.Equal(t, tt.wantHostAndPort, hostAndPort) - }) - } -} diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index d62b51f4..765b1f45 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -344,7 +344,7 @@ func TestLDAPSearch(t *testing.T) { username: "pinny", password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), - wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, + wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, }, { name: "when the server is not parsable with StartTLS", @@ -355,7 +355,7 @@ func TestLDAPSearch(t *testing.T) { p.ConnectionProtocol = upstreamldap.StartTLS p.Host = "too:many:ports" })), - wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`, + wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, }, { name: "when the CA bundle is not parsable with TLS",