More LDAP WIP: started controller and LDAP server connection code

Both are unfinished works in progress.
This commit is contained in:
Ryan Richard 2021-04-09 18:49:43 -07:00
parent 7781a2e17a
commit 05daa9eff5
8 changed files with 668 additions and 34 deletions

View File

@ -241,6 +241,17 @@ func startControllers(
klogr.New(), klogr.New(),
controllerlib.WithInformer, controllerlib.WithInformer,
), ),
singletonWorker).
WithController(
upstreamwatcher.NewLDAPUpstreamWatcherController(
dynamicUpstreamIDPProvider,
// nil means to use a real production dialer when creating objects to add to the dynamicUpstreamIDPProvider cache.
nil,
pinnipedClient,
pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(),
secretInformer,
controllerlib.WithInformer,
),
singletonWorker) singletonWorker)
kubeInformers.Start(ctx.Done()) kubeInformers.Start(ctx.Done())

View File

@ -1,8 +1,8 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved. // Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package ldap contains common LDAP functionality needed by Pinniped. // Package authenticators contains authenticator interfaces.
package ldap package authenticators
import ( import (
"context" "context"

View File

@ -0,0 +1,97 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package upstreamwatcher
import (
"fmt"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/labels"
corev1informers "k8s.io/client-go/informers/core/v1"
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
pinnipedclientset "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned"
idpinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions/idp/v1alpha1"
pinnipedcontroller "go.pinniped.dev/internal/controller"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/upstreamldap"
)
const (
ldapControllerName = "ldap-upstream-observer"
ldapBindAccountSecretType = corev1.SecretTypeBasicAuth
)
// UpstreamLDAPIdentityProviderICache is a thread safe cache that holds a list of validated upstream LDAP IDP configurations.
type UpstreamLDAPIdentityProviderICache interface {
SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI)
}
type ldapWatcherController struct {
cache UpstreamLDAPIdentityProviderICache
ldapDialFunc upstreamldap.LDAPDialerFunc
client pinnipedclientset.Interface
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer
secretInformer corev1informers.SecretInformer
}
// NewLDAPUpstreamWatcherController instantiates a new controllerlib.Controller which will populate the provided UpstreamLDAPIdentityProviderICache.
func NewLDAPUpstreamWatcherController(
idpCache UpstreamLDAPIdentityProviderICache,
ldapDialFunc upstreamldap.LDAPDialerFunc,
client pinnipedclientset.Interface,
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer,
secretInformer corev1informers.SecretInformer,
withInformer pinnipedcontroller.WithInformerOptionFunc,
) controllerlib.Controller {
c := ldapWatcherController{
cache: idpCache,
ldapDialFunc: ldapDialFunc,
client: client,
ldapIdentityProviderInformer: ldapIdentityProviderInformer,
secretInformer: secretInformer,
}
return controllerlib.New(
controllerlib.Config{Name: ldapControllerName, Syncer: &c},
withInformer(
ldapIdentityProviderInformer,
pinnipedcontroller.MatchAnythingFilter(pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{},
),
withInformer(
secretInformer,
pinnipedcontroller.MatchAnySecretOfTypeFilter(ldapBindAccountSecretType, pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{},
),
)
}
// Sync implements controllerlib.Syncer.
func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error {
actualUpstreams, err := c.ldapIdentityProviderInformer.Lister().List(labels.Everything())
if err != nil {
return fmt.Errorf("failed to list LDAPIdentityProviders: %w", err)
}
requeue := false
validatedUpstreams := make([]provider.UpstreamLDAPIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams {
valid := c.validateUpstream(upstream)
if valid == nil {
requeue = true
} else {
validatedUpstreams = append(validatedUpstreams, valid)
}
}
c.cache.SetLDAPIdentityProviders(validatedUpstreams)
if requeue {
return controllerlib.ErrSyntheticRequeue
}
return nil
}
func (c *ldapWatcherController) validateUpstream(upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI {
return &upstreamldap.Provider{Name: upstream.Name, Dial: c.ldapDialFunc}
}

View File

@ -0,0 +1,253 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package upstreamwatcher
import (
"context"
"testing"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes/fake"
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
pinnipedfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/upstreamldap"
)
func TestLDAPUpstreamWatcherControllerFilterSecrets(t *testing.T) {
t.Parallel()
tests := []struct {
name string
secret metav1.Object
wantAdd bool
wantUpdate bool
wantDelete bool
}{
{
name: "a secret of the right type",
secret: &corev1.Secret{
Type: corev1.SecretTypeBasicAuth,
ObjectMeta: metav1.ObjectMeta{Name: "some-name", Namespace: "some-namespace"},
},
wantAdd: true,
wantUpdate: true,
wantDelete: true,
},
{
name: "a secret of the wrong type",
secret: &corev1.Secret{
Type: "this-is-the-wrong-type",
ObjectMeta: metav1.ObjectMeta{Name: "some-name", Namespace: "some-namespace"},
},
},
{
name: "resource of a data type which is not watched by this controller",
secret: &corev1.Namespace{
ObjectMeta: metav1.ObjectMeta{Name: "some-name", Namespace: "some-namespace"},
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fakePinnipedClient := pinnipedfake.NewSimpleClientset()
pinnipedInformers := pinnipedinformers.NewSharedInformerFactory(fakePinnipedClient, 0)
ldapIDPInformer := pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders()
fakeKubeClient := fake.NewSimpleClientset()
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
secretInformer := kubeInformers.Core().V1().Secrets()
withInformer := testutil.NewObservableWithInformerOption()
NewLDAPUpstreamWatcherController(nil, nil, nil, ldapIDPInformer, secretInformer, withInformer.WithInformer)
unrelated := corev1.Secret{}
filter := withInformer.GetFilterForInformer(secretInformer)
require.Equal(t, test.wantAdd, filter.Add(test.secret))
require.Equal(t, test.wantUpdate, filter.Update(&unrelated, test.secret))
require.Equal(t, test.wantUpdate, filter.Update(test.secret, &unrelated))
require.Equal(t, test.wantDelete, filter.Delete(test.secret))
})
}
}
func TestLDAPUpstreamWatcherControllerFilterLDAPIdentityProviders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
idp metav1.Object
wantAdd bool
wantUpdate bool
wantDelete bool
}{
{
name: "any LDAPIdentityProvider",
idp: &v1alpha1.LDAPIdentityProvider{
ObjectMeta: metav1.ObjectMeta{Name: "some-name", Namespace: "some-namespace"},
},
wantAdd: true,
wantUpdate: true,
wantDelete: true,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fakePinnipedClient := pinnipedfake.NewSimpleClientset()
pinnipedInformers := pinnipedinformers.NewSharedInformerFactory(fakePinnipedClient, 0)
ldapIDPInformer := pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders()
fakeKubeClient := fake.NewSimpleClientset()
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
secretInformer := kubeInformers.Core().V1().Secrets()
withInformer := testutil.NewObservableWithInformerOption()
NewLDAPUpstreamWatcherController(nil, nil, nil, ldapIDPInformer, secretInformer, withInformer.WithInformer)
unrelated := corev1.Secret{}
filter := withInformer.GetFilterForInformer(ldapIDPInformer)
require.Equal(t, test.wantAdd, filter.Add(test.idp))
require.Equal(t, test.wantUpdate, filter.Update(&unrelated, test.idp))
require.Equal(t, test.wantUpdate, filter.Update(test.idp, &unrelated))
require.Equal(t, test.wantDelete, filter.Delete(test.idp))
})
}
}
func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
t.Parallel()
var (
testNamespace = "test-namespace"
testName = "test-name"
testSecretName = "test-client-secret"
testBindUsername = "test-bind-username"
testBindPassword = "test-bind-password"
testValidSecretData = map[string][]byte{"username": []byte(testBindUsername), "password": []byte(testBindPassword)}
)
tests := []struct {
name string
inputUpstreams []runtime.Object
inputSecrets []runtime.Object
wantErr string
wantResultingCache []provider.UpstreamLDAPIdentityProviderI
wantResultingUpstreams []v1alpha1.LDAPIdentityProvider
}{
{
name: "no LDAPIdentityProvider upstreams clears the cache",
},
{
name: "one valid upstream updates the cache to include only that upstream",
inputUpstreams: []runtime.Object{&v1alpha1.LDAPIdentityProvider{
ObjectMeta: metav1.ObjectMeta{Name: testName, Namespace: testNamespace, Generation: 1234},
Spec: v1alpha1.LDAPIdentityProviderSpec{
Host: "TODO", // TODO
TLS: &v1alpha1.LDAPIdentityProviderTLSSpec{CertificateAuthorityData: "TODO"}, // TODO
Bind: v1alpha1.LDAPIdentityProviderBindSpec{SecretName: testSecretName},
UserSearch: v1alpha1.LDAPIdentityProviderUserSearchSpec{
Base: "TODO", // TODO
Filter: "TODO", // TODO
Attributes: v1alpha1.LDAPIdentityProviderUserSearchAttributesSpec{
Username: "TODO", // TODO
UniqueID: "TODO", // TODO
},
},
},
}},
inputSecrets: []runtime.Object{&corev1.Secret{
ObjectMeta: metav1.ObjectMeta{Name: testSecretName, Namespace: testNamespace},
Type: corev1.SecretTypeBasicAuth,
Data: testValidSecretData,
}},
wantResultingCache: []provider.UpstreamLDAPIdentityProviderI{
&upstreamldap.Provider{
Name: testName,
// TODO test more stuff
},
},
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.LDAPIdentityProviderStatus{
Phase: "Ready",
// TODO Conditions
},
}},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
fakePinnipedClient := pinnipedfake.NewSimpleClientset(tt.inputUpstreams...)
pinnipedInformers := pinnipedinformers.NewSharedInformerFactory(fakePinnipedClient, 0)
fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...)
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{
&upstreamldap.Provider{Name: "initial-entry"},
})
controller := NewLDAPUpstreamWatcherController(
cache,
func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) {
// TODO return a fake implementation of upstreamldap.Conn, or return an error for testing errors
return nil, nil
},
fakePinnipedClient,
pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(),
kubeInformers.Core().V1().Secrets(),
controllerlib.WithInformer,
)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
pinnipedInformers.Start(ctx.Done())
kubeInformers.Start(ctx.Done())
controllerlib.TestRunSynchronously(t, controller)
syncCtx := controllerlib.Context{Context: ctx, Key: controllerlib.Key{}}
if err := controllerlib.TestSync(t, controller, syncCtx); tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
actualIDPList := cache.GetLDAPIdentityProviders()
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList {
actualIDP := actualIDPList[i].(*upstreamldap.Provider)
require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName())
// TODO more assertions
}
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{})
require.NoError(t, err)
// TODO maybe use something like the normalizeUpstreams() helper to make assertions about what was updated
_ = actualUpstreams
// require.ElementsMatch(t, tt.wantResultingUpstreams, actualUpstreams.Items)
// Running the sync() a second time should be idempotent, and should return the same error.
if err := controllerlib.TestSync(t, controller, syncCtx); tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
} else {
require.NoError(t, err)
}
})
}
}

View File

@ -31,7 +31,7 @@ import (
"go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/internal/upstreamoidc"
) )
func TestControllerFilterSecret(t *testing.T) { func TestOIDCUpstreamWatcherControllerFilterSecret(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
@ -101,7 +101,7 @@ func TestControllerFilterSecret(t *testing.T) {
} }
} }
func TestController(t *testing.T) { func TestOIDCUpstreamWatcherControllerSync(t *testing.T) {
t.Parallel() t.Parallel()
now := metav1.NewTime(time.Now().UTC()) now := metav1.NewTime(time.Now().UTC())
earlier := metav1.NewTime(now.Add(-1 * time.Hour).UTC()) earlier := metav1.NewTime(now.Add(-1 * time.Hour).UTC())

View File

@ -10,7 +10,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"go.pinniped.dev/internal/ldap" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
@ -59,7 +59,7 @@ type UpstreamLDAPIdentityProviderI interface {
GetURL() string GetURL() string
// A method for performing user authentication against the upstream LDAP provider. // A method for performing user authentication against the upstream LDAP provider.
ldap.UserAuthenticator authenticators.UserAuthenticator
} }
type DynamicUpstreamIDPProvider interface { type DynamicUpstreamIDPProvider interface {

View File

@ -6,65 +6,168 @@ package upstreamldap
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"strings"
ldap "github.com/go-ldap/ldap/v3" ldap "github.com/go-ldap/ldap/v3"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
) )
const (
ldapsScheme = "ldaps"
)
// Conn abstracts the upstream LDAP communication protocol (mostly for testing). // Conn abstracts the upstream LDAP communication protocol (mostly for testing).
type Conn interface { type Conn interface {
// Bind abstracts ldap.Conn.Bind().
Bind(username, password string) error Bind(username, password string) error
// Search abstracts ldap.Conn.Search().
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
// Close abstracts ldap.Conn.Close().
Close() Close()
} }
// Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn.
var _ Conn = &ldap.Conn{}
// LDAPDialerFunc is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP.
type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error)
// Provider includes all of the settings for connection and searching for users and groups in
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
type Provider struct {
// Name is the unique name of this upstream LDAP IDP.
Name string
// Host is the hostname or "hostname:port" of the LDAP server. When the port is not specified,
// the default LDAP port will be used.
Host string
// PEM-encoded CA cert bundle to trust when connecting to the LDAP server.
CABundle []byte
// BindUsername is the username to use when performing a bind with the upstream LDAP IDP.
BindUsername string
// BindPassword is the password to use when performing a bind with the upstream LDAP IDP.
BindPassword string
// UserSearch contains information about how to search for users in the upstream LDAP IDP.
UserSearch *UserSearch
// Dial exists to enable testing. When nil, will use a default appropriate for production use.
Dial LDAPDialerFunc
}
// UserSearch contains information about how to search for users in the upstream LDAP IDP. // UserSearch contains information about how to search for users in the upstream LDAP IDP.
type UserSearch struct { type UserSearch struct {
// Base is the base DN to use for the user search in the upstream LDAP IDP. // Base is the base DN to use for the user search in the upstream LDAP IDP.
Base string Base string
// Filter is the filter to use for the user search in the upstream LDAP IDP. // Filter is the filter to use for the user search in the upstream LDAP IDP.
Filter string Filter string
// UsernameAttribute is the attribute in the LDAP entry from which the username should be // UsernameAttribute is the attribute in the LDAP entry from which the username should be
// retrieved. // retrieved.
UsernameAttribute string UsernameAttribute string
// UIDAttribute is the attribute in the LDAP entry from which the user's unique ID should be // UIDAttribute is the attribute in the LDAP entry from which the user's unique ID should be
// retrieved. // retrieved.
UIDAttribute string UIDAttribute string
} }
// Provider contains can interact with an upstream LDAP IDP. func (p *Provider) dial(ctx context.Context) (Conn, error) {
type Provider struct { hostAndPort, err := hostAndPortWithDefaultPort(p.Host, ldap.DefaultLdapsPort)
// Name is the unique name of this upstream LDAP IDP. if err != nil {
Name string return nil, ldap.NewError(ldap.ErrorNetwork, err)
// URL is the URL of this upstream LDAP IDP. }
URL string if p.Dial != nil {
return p.Dial(ctx, hostAndPort)
// Dial is a func that, given a URL, will return an LDAPConn to use for communicating with an }
// upstream LDAP IDP. return p.dialTLS(ctx, hostAndPort)
Dial func(ctx context.Context, url string) (Conn, error)
// BindUsername is the username to use when performing a bind with the upstream LDAP IDP.
BindUsername string
// BindPassword is the password to use when performing a bind with the upstream LDAP IDP.
BindPassword string
// UserSearch contains information about how to search for users in the upstream LDAP IDP.
UserSearch *UserSearch
} }
// dialTLS is the default implementation of the Dial func, used when Dial is nil.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {
rootCAs := x509.NewCertPool()
if p.CABundle != nil {
if !rootCAs.AppendCertsFromPEM(p.CABundle) {
return nil, ldap.NewError(ldap.ErrorNetwork, fmt.Errorf("could not parse CA bundle"))
}
}
dialer := &tls.Dialer{Config: &tls.Config{
MinVersion: tls.VersionTLS12,
RootCAs: rootCAs,
}}
c, err := dialer.DialContext(ctx, "tcp", hostAndPort)
if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err)
}
conn := ldap.NewConn(c, true)
conn.Start()
return conn, nil
}
// Adds the default port if hostAndPort did not already include a port.
func hostAndPortWithDefaultPort(hostAndPort string, defaultPort string) (string, error) {
host, port, err := net.SplitHostPort(hostAndPort)
if err != nil {
if strings.HasSuffix(err.Error(), ": missing port in address") { // sad to need to do this string compare
host = hostAndPort
port = defaultPort
} else {
return "", err // hostAndPort argument was not parsable
}
}
switch {
case port != "" && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
// don't add extra square brackets to an IPv6 address that already has them
return host + ":" + port, nil
case port != "":
return net.JoinHostPort(host, port), nil
default:
return host, nil
}
}
// A name for this upstream provider.
func (p *Provider) GetName() string { func (p *Provider) GetName() string {
return p.Name return p.Name
} }
// Return a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234".
// This URL is not used for connecting to the provider, but rather is used for creating a globally unique user
// identifier by being combined with the user's UID, since user UIDs are only unique within one provider.
func (p *Provider) GetURL() string { func (p *Provider) GetURL() string {
return p.URL return fmt.Sprintf("%s://%s", ldapsScheme, p.Host)
} }
// TestConnection provides a method for testing the connection and bind settings by dialing and binding.
func (p *Provider) TestConnection(ctx context.Context) error {
_, _ = p.dial(ctx)
// TODO bind using the bind credentials
// TODO close
// TODO return any dial or bind errors
return nil
}
// Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator.
func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) { func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) {
// TODO: test context timeout? _, _ = p.dial(ctx)
// TODO: test dial context timeout? // TODO bind
// TODO user search
// TODO user bind
// TODO map username and uid attributes
// TODO group search
// TODO map group attributes
// TODO close
// TODO return any errors that were encountered along the way
return nil, false, nil return nil, false, nil
} }

View File

@ -5,15 +5,21 @@ package upstreamldap
import ( import (
"context" "context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"testing" "testing"
ldap "github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/testutil"
) )
var ( var (
@ -39,7 +45,7 @@ func TestAuthenticateUser(t *testing.T) {
{ {
name: "happy path", name: "happy path",
provider: &Provider{ provider: &Provider{
URL: "ldaps://some-ldap-url:1234", Host: "ldap.example.com:8443",
BindUsername: upstreamUsername, BindUsername: upstreamUsername,
BindPassword: upstreamPassword, BindPassword: upstreamPassword,
UserSearch: &UserSearch{ UserSearch: &UserSearch{
@ -87,12 +93,15 @@ func TestAuthenticateUser(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
test.provider.Dial = func(ctx context.Context, url string) (Conn, error) { dialWasAttempted := false
require.Equal(t, test.provider.URL, url) test.provider.Dial = func(ctx context.Context, hostAndPort string) (Conn, error) {
dialWasAttempted = true
require.Equal(t, test.provider.Host, hostAndPort)
return conn, nil return conn, nil
} }
authResponse, authenticated, err := test.provider.AuthenticateUser(context.Background(), upstreamUsername, upstreamPassword) authResponse, authenticated, err := test.provider.AuthenticateUser(context.Background(), upstreamUsername, upstreamPassword)
require.True(t, dialWasAttempted, "AuthenticateUser was supposed to try to dial, but didn't")
if test.wantError != "" { if test.wantError != "" {
require.EqualError(t, err, test.wantError) require.EqualError(t, err, test.wantError)
return return
@ -102,3 +111,164 @@ func TestAuthenticateUser(t *testing.T) {
}) })
} }
} }
func TestGetURL(t *testing.T) {
require.Equal(t, "ldaps://ldap.example.com:1234", (&Provider{Host: "ldap.example.com:1234"}).GetURL())
require.Equal(t, "ldaps://ldap.example.com", (&Provider{Host: "ldap.example.com"}).GetURL())
}
// Testing of host parsing, TLS negotiation, and CA bundle, etc. for the production code's dialer.
func TestRealTLSDialing(t *testing.T) {
testServerCABundle, testServerURL := testutil.TLSTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
parsedURL, err := url.Parse(testServerURL)
require.NoError(t, err)
testServerHostAndPort := parsedURL.Host
unusedPortGrabbingListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
recentlyClaimedHostAndPort := unusedPortGrabbingListener.Addr().String()
require.NoError(t, unusedPortGrabbingListener.Close())
alreadyCancelledContext, cancelFunc := context.WithCancel(context.Background())
cancelFunc() // cancel it immediately
tests := []struct {
name string
host string
caBundle []byte
context context.Context
wantError string
}{
{
name: "happy path",
host: testServerHostAndPort,
caBundle: []byte(testServerCABundle),
context: context.Background(),
},
{
name: "invalid CA bundle",
host: testServerHostAndPort,
caBundle: []byte("not a ca bundle"),
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`,
},
{
name: "missing CA bundle when it is required because the host is not using a trusted CA",
host: testServerHostAndPort,
caBundle: nil,
context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": x509: certificate signed by unknown authority`,
},
{
name: "cannot connect to host",
// This is assuming that this port was not reclaimed by another app since the test setup ran. Seems safe enough.
host: recentlyClaimedHostAndPort,
caBundle: []byte(testServerCABundle),
context: context.Background(),
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort),
},
{
name: "pays attention to the passed context",
host: testServerHostAndPort,
caBundle: []byte(testServerCABundle),
context: alreadyCancelledContext,
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
provider := &Provider{
Host: test.host,
CABundle: test.caBundle,
Dial: nil, // this test is for the default (production) dialer
}
conn, err := provider.dial(test.context)
if conn != nil {
defer conn.Close()
}
if test.wantError != "" {
require.Nil(t, conn)
require.EqualError(t, err, test.wantError)
} else {
require.NoError(t, err)
require.NotNil(t, conn)
// Should be an instance of the real production LDAP client type.
// Can't test its methods here because we are not dialed to a real LDAP server.
require.IsType(t, &ldap.Conn{}, conn)
// Indirectly checking that the Dial method constructed the ldap.Conn with isTLS set to true,
// since this is always the correct behavior unless/until we want to support StartTLS.
err := conn.(*ldap.Conn).StartTLS(&tls.Config{})
require.EqualError(t, err, `LDAP Result Code 200 "Network Error": ldap: already encrypted`)
}
})
}
}
// Test various cases of host and port parsing.
func TestHostAndPortWithDefaultPort(t *testing.T) {
tests := []struct {
name string
hostAndPort string
defaultPort string
wantError string
wantHostAndPort string
}{
{
name: "host already has port",
hostAndPort: "host.example.com:99",
defaultPort: "42",
wantHostAndPort: "host.example.com:99",
},
{
name: "host does not have port",
hostAndPort: "host.example.com",
defaultPort: "42",
wantHostAndPort: "host.example.com:42",
},
{
name: "host does not have port and default port is empty",
hostAndPort: "host.example.com",
defaultPort: "",
wantHostAndPort: "host.example.com",
},
{
name: "IPv6 host already has port",
hostAndPort: "[::1%lo0]:80",
defaultPort: "42",
wantHostAndPort: "[::1%lo0]:80",
},
{
name: "IPv6 host does not have port",
hostAndPort: "[::1%lo0]",
defaultPort: "42",
wantHostAndPort: "[::1%lo0]:42",
},
{
name: "IPv6 host does not have port and default port is empty",
hostAndPort: "[::1%lo0]",
defaultPort: "",
wantHostAndPort: "[::1%lo0]",
},
{
name: "host is not valid",
hostAndPort: "host.example.com:port1:port2",
defaultPort: "42",
wantError: "address host.example.com:port1:port2: too many colons in address",
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
hostAndPort, err := hostAndPortWithDefaultPort(test.hostAndPort, test.defaultPort)
if test.wantError != "" {
require.EqualError(t, err, test.wantError)
} else {
require.NoError(t, err)
}
require.Equal(t, test.wantHostAndPort, hostAndPort)
})
}
}