Merge pull request #640 from mattmoyer/endpointaddr

Refactor "endpoint address" parsing code into shared package.
This commit is contained in:
Matt Moyer 2021-05-25 17:22:31 -05:00 committed by GitHub
commit 18a2a27a06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 283 additions and 187 deletions

View File

@ -26,6 +26,7 @@ import (
pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions"
"go.pinniped.dev/internal/certauthority"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
@ -871,9 +872,9 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
tt.setupMocks(conn)
}
dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) {
dialer := &comparableDialer{upstreamldap.LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (upstreamldap.Conn, error) {
if tt.dialErrors != nil {
dialErr := tt.dialErrors[hostAndPort]
dialErr := tt.dialErrors[addr.Endpoint()]
if dialErr != nil {
return nil, dialErr
}

View File

@ -0,0 +1,71 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package endpointaddr implements parsing and validation of "<host>[:<port>]" strings for Pinniped APIs.
package endpointaddr
import (
"fmt"
"net"
"strconv"
"k8s.io/apimachinery/pkg/util/validation"
)
type HostPort struct {
// Host is the validated host part of the input, which may be a hostname or IP.
//
// This string can be be used as an x509 certificate SAN.
Host string
// Port is the validated port number, which may be defaulted.
Port uint16
}
// Endpoint is the host:port validated from the input, where port may be a default value.
//
// This string can be passed to net.Dial.
func (h *HostPort) Endpoint() string {
return net.JoinHostPort(h.Host, strconv.Itoa(int(h.Port)))
}
// Parse an "endpoint address" string, providing a default port. The input can be in several valid formats:
//
// - "<hostname>" (DNS hostname)
// - "<IPv4>" (IPv4 address)
// - "<IPv6>" (IPv6 address)
// - "<hostname>:<port>" (DNS hostname with port)
// - "<IPv4>:<port>" (IPv4 address with port)
// - "[<IPv6>]:<port>" (IPv6 address with port, brackets are required)
//
// If the input does not not specify a port number, then defaultPort will be used.
func Parse(endpoint string, defaultPort uint16) (HostPort, error) {
// Try parsing it both with and without an implicit port 443 at the end.
host, port, err := net.SplitHostPort(endpoint)
// If we got an error parsing the raw input, try adding the default port.
if err != nil {
host, port, err = net.SplitHostPort(net.JoinHostPort(endpoint, strconv.Itoa(int(defaultPort))))
}
// Give up if there's still an error splitting the host and port.
if err != nil {
return HostPort{}, err
}
// Parse the port number is an integer in the range of valid ports.
integerPort, _ := strconv.Atoi(port)
if len(validation.IsValidPortNum(integerPort)) > 0 {
return HostPort{}, fmt.Errorf("invalid port %q", port)
}
// Check if the host part is a IPv4 or IPv6 address or a valid hostname according to RFC 1123.
switch {
case len(validation.IsValidIP(host)) == 0:
case len(validation.IsDNS1123Subdomain(host)) == 0:
default:
return HostPort{}, fmt.Errorf("host %q is not a valid hostname or IP address", host)
}
return HostPort{Host: host, Port: uint16(integerPort)}, nil
}

View File

@ -0,0 +1,182 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package endpointaddr
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParse(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
input string
defaultPort uint16
expectErr string
expect HostPort
expectEndpoint string
}{
{
name: "plain IPv4",
input: "127.0.0.1",
defaultPort: 443,
expect: HostPort{Host: "127.0.0.1", Port: 443},
expectEndpoint: "127.0.0.1:443",
},
{
name: "IPv4 with port",
input: "127.0.0.1:8443",
defaultPort: 443,
expect: HostPort{Host: "127.0.0.1", Port: 8443},
expectEndpoint: "127.0.0.1:8443",
},
{
name: "IPv4 in brackets with port",
input: "[127.0.0.1]:8443",
defaultPort: 443,
expect: HostPort{Host: "127.0.0.1", Port: 8443},
expectEndpoint: "127.0.0.1:8443",
},
{
name: "IPv4 as IPv6 in brackets with port",
input: "[::127.0.0.1]:8443",
defaultPort: 443,
expect: HostPort{Host: "::127.0.0.1", Port: 8443},
expectEndpoint: "[::127.0.0.1]:8443",
},
{
name: "IPv4 as IPv6 without port",
input: "::127.0.0.1",
defaultPort: 443,
expect: HostPort{Host: "::127.0.0.1", Port: 443},
expectEndpoint: "[::127.0.0.1]:443",
},
{
name: "plain IPv6 without port",
input: "2001:db8::ffff",
defaultPort: 443,
expect: HostPort{Host: "2001:db8::ffff", Port: 443},
expectEndpoint: "[2001:db8::ffff]:443",
},
{
name: "IPv6 with port",
input: "[2001:db8::ffff]:8443",
defaultPort: 443,
expect: HostPort{Host: "2001:db8::ffff", Port: 8443},
expectEndpoint: "[2001:db8::ffff]:8443",
},
{
name: "plain hostname",
input: "host.example.com",
defaultPort: 443,
expect: HostPort{Host: "host.example.com", Port: 443},
expectEndpoint: "host.example.com:443",
},
{
name: "plain hostname with dash",
input: "host-dev.example.com",
defaultPort: 443,
expect: HostPort{Host: "host-dev.example.com", Port: 443},
expectEndpoint: "host-dev.example.com:443",
},
{
name: "hostname with port",
input: "host.example.com:8443",
defaultPort: 443,
expect: HostPort{Host: "host.example.com", Port: 8443},
expectEndpoint: "host.example.com:8443",
},
{
name: "hostname in brackets with port",
input: "[host.example.com]:8443",
defaultPort: 443,
expect: HostPort{Host: "host.example.com", Port: 8443},
expectEndpoint: "host.example.com:8443",
},
{
name: "hostname without dots",
input: "localhost",
defaultPort: 443,
expect: HostPort{Host: "localhost", Port: 443},
expectEndpoint: "localhost:443",
},
{
name: "hostname and port without dots",
input: "localhost:8443",
defaultPort: 443,
expect: HostPort{Host: "localhost", Port: 8443},
expectEndpoint: "localhost:8443",
},
{
name: "invalid empty string",
input: "",
defaultPort: 443,
expectErr: `host "" is not a valid hostname or IP address`,
},
{
// IPv6 zone index specifiers are not yet supported.
name: "IPv6 with port and zone index",
input: "[2001:db8::ffff%lo0]:8443",
defaultPort: 443,
expectErr: `host "2001:db8::ffff%lo0" is not a valid hostname or IP address`,
},
{
name: "IPv6 in brackets without port",
input: "[2001:db8::ffff]",
defaultPort: 443,
expectErr: `address [[2001:db8::ffff]]:443: missing port in address`,
},
{
name: "invalid HTTPS URL",
input: "https://host.example.com",
defaultPort: 443,
expectErr: `invalid port "//host.example.com"`,
},
{
name: "invalid host with URL path",
input: "host.example.com/some/path",
defaultPort: 443,
expectErr: `host "host.example.com/some/path" is not a valid hostname or IP address`,
},
{
name: "invalid host with mismatched brackets",
input: "[host.example.com",
defaultPort: 443,
expectErr: "address [host.example.com:443: missing ']' in address",
},
{
name: "invalid host with underscores",
input: "___.example.com:1234",
defaultPort: 443,
expectErr: `host "___.example.com" is not a valid hostname or IP address`,
},
{
name: "invalid host with uppercase",
input: "HOST.EXAMPLE.COM",
defaultPort: 443,
expectErr: `host "HOST.EXAMPLE.COM" is not a valid hostname or IP address`,
},
{
name: "invalid host with extra port",
input: "host.example.com:port1:port2",
defaultPort: 443,
expectErr: `host "host.example.com:port1:port2" is not a valid hostname or IP address`,
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
got, err := Parse(tt.input, tt.defaultPort)
if tt.expectErr == "" {
assert.NoError(t, err)
assert.Equal(t, tt.expect, got)
assert.Equal(t, tt.expectEndpoint, got.Endpoint())
} else {
assert.EqualError(t, err, tt.expectErr)
assert.Equal(t, HostPort{}, got)
}
})
}
}

View File

@ -21,6 +21,7 @@ import (
"k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user"
"go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/plog"
)
@ -30,6 +31,8 @@ const (
commonNameAttributeName = "cn"
searchFilterInterpolationLocationMarker = "{}"
groupSearchPageSize = uint32(250)
defaultLDAPPort = uint16(389)
defaultLDAPSPort = uint16(636)
)
// Conn abstracts the upstream LDAP communication protocol (mostly for testing).
@ -48,16 +51,16 @@ var _ Conn = &ldap.Conn{}
// LDAPDialer is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP.
type LDAPDialer interface {
Dial(ctx context.Context, hostAndPort string) (Conn, error)
Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error)
}
// LDAPDialerFunc makes it easy to use a func as an LDAPDialer.
type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error)
type LDAPDialerFunc func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error)
var _ LDAPDialer = LDAPDialerFunc(nil)
func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, error) {
return f(ctx, hostAndPort)
func (f LDAPDialerFunc) Dial(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
return f(ctx, addr)
}
type LDAPConnectionProtocol string
@ -147,26 +150,26 @@ func (p *Provider) GetConfig() ProviderConfig {
}
func (p *Provider) dial(ctx context.Context) (Conn, error) {
tlsHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapsPort)
tlsAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPSPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
startTLSHostAndPort, err := hostAndPortWithDefaultPort(p.c.Host, ldap.DefaultLdapPort)
startTLSAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
// Choose how and where to dial based on TLS vs. StartTLS config option.
var dialFunc LDAPDialerFunc
var hostAndPort string
var addr endpointaddr.HostPort
switch {
case p.c.ConnectionProtocol == TLS:
dialFunc = p.dialTLS
hostAndPort = tlsHostAndPort
addr = tlsAddr
case p.c.ConnectionProtocol == StartTLS:
dialFunc = p.dialStartTLS
hostAndPort = startTLSHostAndPort
addr = startTLSAddr
default:
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("did not specify valid ConnectionProtocol"))
}
@ -176,20 +179,20 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) {
dialFunc = p.c.Dialer.Dial
}
return dialFunc(ctx, hostAndPort)
return dialFunc(ctx, addr)
}
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is TLS.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
func (p *Provider) dialTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
tlsConfig, err := p.tlsConfig()
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
dialer := &tls.Dialer{NetDialer: netDialer(), Config: tlsConfig}
c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
c, err := dialer.DialContext(ctx, "tcp", addr.Endpoint())
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
@ -202,20 +205,16 @@ func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error
// dialTLS is a default implementation of the Dialer, used when Dialer is nil and ConnectionProtocol is StartTLS.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialStartTLS(ctx context.Context, hostAndPort string) (Conn, error) {
func (p *Provider) dialStartTLS(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
tlsConfig, err := p.tlsConfig()
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
host, err := hostWithoutPort(hostAndPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
// Unfortunately, this seems to be required for StartTLS, even though it is not needed for regular TLS.
tlsConfig.ServerName = host
tlsConfig.ServerName = addr.Host
c, err := netDialer().DialContext(ctx, "tcp", hostAndPort)
c, err := netDialer().DialContext(ctx, "tcp", addr.Endpoint())
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
@ -245,44 +244,6 @@ func (p *Provider) tlsConfig() (*tls.Config, error) {
return &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: rootCAs}, nil
}
// Adds the default port if hostAndPort did not already include a port.
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
host, port, err := net.SplitHostPort(hostAndPort)
if err != nil {
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
host = hostAndPort
port = defaultPort
} else {
return "", err // hostAndPort argument was not parsable
}
}
switch {
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
// don't add extra square brackets to an IPv6 address that already has them
return fmt.Sprintf("%s:%s", host, port), nil
case port != "":
return net.JoinHostPort(host, port), nil
default:
return host, nil
}
}
// Strip the port from a host or host:port.
func hostWithoutPort(hostAndPort string) (string, error) {
host, _, err := net.SplitHostPort(hostAndPort)
if err != nil {
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
return hostAndPort, nil
}
return "", err // hostAndPort argument was not parsable
}
if strings.HasPrefix(hostAndPort, "[") {
// it was an IPv6 address, so preserve the square brackets.
return fmt.Sprintf("[%s]", host), nil
}
return host, nil
}
// A name for this upstream provider.
func (p *Provider) GetName() string {
return p.c.Name

View File

@ -21,6 +21,7 @@ import (
"k8s.io/apiserver/pkg/authentication/user"
"go.pinniped.dev/internal/certauthority"
"go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/testutil"
)
@ -926,9 +927,9 @@ func TestEndUserAuthentication(t *testing.T) {
}
dialWasAttempted := false
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
dialWasAttempted = true
require.Equal(t, tt.providerConfig.Host, hostAndPort)
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
if tt.dialError != nil {
return nil, tt.dialError
}
@ -1061,9 +1062,9 @@ func TestTestConnection(t *testing.T) {
}
dialWasAttempted := false
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
dialWasAttempted = true
require.Equal(t, tt.providerConfig.Host, hostAndPort)
require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
if tt.dialError != nil {
return nil, tt.dialError
}
@ -1185,7 +1186,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: []byte(testServerCABundle),
connProto: TLS,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`,
},
{
name: "invalid host with StartTLS",
@ -1193,7 +1194,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: []byte(testServerCABundle),
connProto: StartTLS,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": address this:is:not:a:valid:hostname: too many colons in address`,
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`,
},
{
name: "missing CA bundle when it is required because the host is not using a trusted CA",
@ -1261,123 +1262,3 @@ func TestRealTLSDialing(t *testing.T) {
})
}
}
// Test various cases of host and port parsing.
func TestHostAndPortWithDefaultPort(t *testing.T) {
tests := []struct {
name string
hostAndPort string
defaultPort string
wantError string
wantHostAndPort string
}{
{
name: "host already has port",
hostAndPort: "host.example.com:99",
defaultPort: "42",
wantHostAndPort: "host.example.com:99",
},
{
name: "host does not have port",
hostAndPort: "host.example.com",
defaultPort: "42",
wantHostAndPort: "host.example.com:42",
},
{
name: "host does not have port and default port is empty",
hostAndPort: "host.example.com",
defaultPort: "",
wantHostAndPort: "host.example.com",
},
{
name: "host has port and default port is empty",
hostAndPort: "host.example.com:42",
defaultPort: "",
wantHostAndPort: "host.example.com:42",
},
{
name: "IPv6 host already has port",
hostAndPort: "[::1%lo0]:80",
defaultPort: "42",
wantHostAndPort: "[::1%lo0]:80",
},
{
name: "IPv6 host does not have port",
hostAndPort: "[::1%lo0]",
defaultPort: "42",
wantHostAndPort: "[::1%lo0]:42",
},
{
name: "IPv6 host does not have port and default port is empty",
hostAndPort: "[::1%lo0]",
defaultPort: "",
wantHostAndPort: "[::1%lo0]",
},
{
name: "host is not valid",
hostAndPort: "host.example.com:port1:port2",
defaultPort: "42",
wantError: "address host.example.com:port1:port2: too many colons in address",
},
}
for _, test := range tests {
tt := test
t.Run(tt.name, func(t *testing.T) {
hostAndPort, err := hostAndPortWithDefaultPort(tt.hostAndPort, tt.defaultPort)
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantHostAndPort, hostAndPort)
})
}
}
// Test various cases of host and port parsing.
func TestHostWithoutPort(t *testing.T) {
tests := []struct {
name string
hostAndPort string
wantError string
wantHostAndPort string
}{
{
name: "host already has port",
hostAndPort: "host.example.com:99",
wantHostAndPort: "host.example.com",
},
{
name: "host does not have port",
hostAndPort: "host.example.com",
wantHostAndPort: "host.example.com",
},
{
name: "IPv6 host already has port",
hostAndPort: "[::1%lo0]:80",
wantHostAndPort: "[::1%lo0]",
},
{
name: "IPv6 host does not have port",
hostAndPort: "[::1%lo0]",
wantHostAndPort: "[::1%lo0]",
},
{
name: "host is not valid",
hostAndPort: "host.example.com:port1:port2",
wantError: "address host.example.com:port1:port2: too many colons in address",
},
}
for _, test := range tests {
tt := test
t.Run(tt.name, func(t *testing.T) {
hostAndPort, err := hostWithoutPort(tt.hostAndPort)
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantHostAndPort, hostAndPort)
})
}
}

View File

@ -344,7 +344,7 @@ func TestLDAPSearch(t *testing.T) {
username: "pinny",
password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`,
},
{
name: "when the server is not parsable with StartTLS",
@ -355,7 +355,7 @@ func TestLDAPSearch(t *testing.T) {
p.ConnectionProtocol = upstreamldap.StartTLS
p.Host = "too:many:ports"
})),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": address too:many:ports: too many colons in address`,
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`,
},
{
name: "when the CA bundle is not parsable with TLS",