Add a little more logic to ldap_upstream_watcher.go

This commit is contained in:
Ryan Richard 2021-04-12 11:23:08 -07:00
parent 05daa9eff5
commit 05571abb74
4 changed files with 119 additions and 38 deletions

View File

@ -31,7 +31,7 @@ type UpstreamLDAPIdentityProviderICache interface {
type ldapWatcherController struct { type ldapWatcherController struct {
cache UpstreamLDAPIdentityProviderICache cache UpstreamLDAPIdentityProviderICache
ldapDialFunc upstreamldap.LDAPDialerFunc ldapDialer upstreamldap.LDAPDialer
client pinnipedclientset.Interface client pinnipedclientset.Interface
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer
secretInformer corev1informers.SecretInformer secretInformer corev1informers.SecretInformer
@ -40,7 +40,7 @@ type ldapWatcherController struct {
// NewLDAPUpstreamWatcherController instantiates a new controllerlib.Controller which will populate the provided UpstreamLDAPIdentityProviderICache. // NewLDAPUpstreamWatcherController instantiates a new controllerlib.Controller which will populate the provided UpstreamLDAPIdentityProviderICache.
func NewLDAPUpstreamWatcherController( func NewLDAPUpstreamWatcherController(
idpCache UpstreamLDAPIdentityProviderICache, idpCache UpstreamLDAPIdentityProviderICache,
ldapDialFunc upstreamldap.LDAPDialerFunc, ldapDialer upstreamldap.LDAPDialer,
client pinnipedclientset.Interface, client pinnipedclientset.Interface,
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer, ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer,
secretInformer corev1informers.SecretInformer, secretInformer corev1informers.SecretInformer,
@ -48,7 +48,7 @@ func NewLDAPUpstreamWatcherController(
) controllerlib.Controller { ) controllerlib.Controller {
c := ldapWatcherController{ c := ldapWatcherController{
cache: idpCache, cache: idpCache,
ldapDialFunc: ldapDialFunc, ldapDialer: ldapDialer,
client: client, client: client,
ldapIdentityProviderInformer: ldapIdentityProviderInformer, ldapIdentityProviderInformer: ldapIdentityProviderInformer,
secretInformer: secretInformer, secretInformer: secretInformer,
@ -93,5 +93,44 @@ func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error {
} }
func (c *ldapWatcherController) validateUpstream(upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI { func (c *ldapWatcherController) validateUpstream(upstream *v1alpha1.LDAPIdentityProvider) provider.UpstreamLDAPIdentityProviderI {
return &upstreamldap.Provider{Name: upstream.Name, Dial: c.ldapDialFunc} spec := upstream.Spec
result := &upstreamldap.Provider{
Name: upstream.Name,
Host: spec.Host,
CABundle: []byte(spec.TLS.CertificateAuthorityData),
UserSearch: &upstreamldap.UserSearch{
Base: spec.UserSearch.Base,
Filter: spec.UserSearch.Filter,
UsernameAttribute: spec.UserSearch.Attributes.Username,
UIDAttribute: spec.UserSearch.Attributes.UniqueID,
},
Dialer: c.ldapDialer,
}
_ = c.validateSecret(upstream, result)
return result
}
func (c ldapWatcherController) validateSecret(upstream *v1alpha1.LDAPIdentityProvider, result *upstreamldap.Provider) *v1alpha1.Condition {
secretName := upstream.Spec.Bind.SecretName
secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName)
if err != nil {
// TODO
return nil
}
if secret.Type != corev1.SecretTypeBasicAuth {
// TODO
return nil
}
result.BindUsername = string(secret.Data[corev1.BasicAuthUsernameKey])
result.BindPassword = string(secret.Data[corev1.BasicAuthPasswordKey])
if len(result.BindUsername) == 0 || len(result.BindPassword) == 0 {
// TODO
return nil
}
var cond *v1alpha1.Condition // satisfy linter
return cond
} }

View File

@ -127,23 +127,49 @@ func TestLDAPUpstreamWatcherControllerFilterLDAPIdentityProviders(t *testing.T)
} }
} }
// Wrap the func into a struct so the test can do deep equal assertions on instances of upstreamldap.Provider.
type comparableDialer struct {
f upstreamldap.LDAPDialerFunc
}
func (d *comparableDialer) Dial(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) {
return d.f(ctx, hostAndPort)
}
func TestLDAPUpstreamWatcherControllerSync(t *testing.T) { func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
t.Parallel() t.Parallel()
var ( const (
testNamespace = "test-namespace" testNamespace = "test-namespace"
testName = "test-name" testName = "test-name"
testSecretName = "test-client-secret" testSecretName = "test-client-secret"
testBindUsername = "test-bind-username" testBindUsername = "test-bind-username"
testBindPassword = "test-bind-password" testBindPassword = "test-bind-password"
testHost = "ldap.example.com:123"
testCABundle = "test-ca-bundle"
testUserSearchBase = "test-user-search-base"
testUserSearchFilter = "test-user-search-filter"
testUsernameAttrName = "test-username-attr"
testUIDAttrName = "test-uid-attr"
)
var (
testValidSecretData = map[string][]byte{"username": []byte(testBindUsername), "password": []byte(testBindPassword)} testValidSecretData = map[string][]byte{"username": []byte(testBindUsername), "password": []byte(testBindPassword)}
) )
successfulDialer := &comparableDialer{
f: func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) {
// TODO return a fake implementation of upstreamldap.Conn, or return an error for testing errors
return nil, nil
},
}
tests := []struct { tests := []struct {
name string name string
inputUpstreams []runtime.Object inputUpstreams []runtime.Object
inputSecrets []runtime.Object inputSecrets []runtime.Object
ldapDialer upstreamldap.LDAPDialer
wantErr string wantErr string
wantResultingCache []provider.UpstreamLDAPIdentityProviderI wantResultingCache []*upstreamldap.Provider
wantResultingUpstreams []v1alpha1.LDAPIdentityProvider wantResultingUpstreams []v1alpha1.LDAPIdentityProvider
}{ }{
{ {
@ -151,18 +177,19 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
}, },
{ {
name: "one valid upstream updates the cache to include only that upstream", name: "one valid upstream updates the cache to include only that upstream",
ldapDialer: successfulDialer,
inputUpstreams: []runtime.Object{&v1alpha1.LDAPIdentityProvider{ inputUpstreams: []runtime.Object{&v1alpha1.LDAPIdentityProvider{
ObjectMeta: metav1.ObjectMeta{Name: testName, Namespace: testNamespace, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Name: testName, Namespace: testNamespace, Generation: 1234},
Spec: v1alpha1.LDAPIdentityProviderSpec{ Spec: v1alpha1.LDAPIdentityProviderSpec{
Host: "TODO", // TODO Host: testHost,
TLS: &v1alpha1.LDAPIdentityProviderTLSSpec{CertificateAuthorityData: "TODO"}, // TODO TLS: &v1alpha1.LDAPIdentityProviderTLSSpec{CertificateAuthorityData: testCABundle},
Bind: v1alpha1.LDAPIdentityProviderBindSpec{SecretName: testSecretName}, Bind: v1alpha1.LDAPIdentityProviderBindSpec{SecretName: testSecretName},
UserSearch: v1alpha1.LDAPIdentityProviderUserSearchSpec{ UserSearch: v1alpha1.LDAPIdentityProviderUserSearchSpec{
Base: "TODO", // TODO Base: testUserSearchBase,
Filter: "TODO", // TODO Filter: testUserSearchFilter,
Attributes: v1alpha1.LDAPIdentityProviderUserSearchAttributesSpec{ Attributes: v1alpha1.LDAPIdentityProviderUserSearchAttributesSpec{
Username: "TODO", // TODO Username: testUsernameAttrName,
UniqueID: "TODO", // TODO UniqueID: testUIDAttrName,
}, },
}, },
}, },
@ -172,10 +199,20 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
Type: corev1.SecretTypeBasicAuth, Type: corev1.SecretTypeBasicAuth,
Data: testValidSecretData, Data: testValidSecretData,
}}, }},
wantResultingCache: []provider.UpstreamLDAPIdentityProviderI{ wantResultingCache: []*upstreamldap.Provider{
&upstreamldap.Provider{ {
Name: testName, Name: testName,
// TODO test more stuff Host: testHost,
CABundle: []byte(testCABundle),
BindUsername: testBindUsername,
BindPassword: testBindPassword,
UserSearch: &upstreamldap.UserSearch{
Base: testUserSearchBase,
Filter: testUserSearchFilter,
UsernameAttribute: testUsernameAttrName,
UIDAttribute: testUIDAttrName,
},
Dialer: successfulDialer, // the dialer passed to the controller's constructor should have been passed through
}, },
}, },
wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{ wantResultingUpstreams: []v1alpha1.LDAPIdentityProvider{{
@ -202,10 +239,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
controller := NewLDAPUpstreamWatcherController( controller := NewLDAPUpstreamWatcherController(
cache, cache,
func(ctx context.Context, hostAndPort string) (upstreamldap.Conn, error) { successfulDialer,
// TODO return a fake implementation of upstreamldap.Conn, or return an error for testing errors
return nil, nil
},
fakePinnipedClient, fakePinnipedClient,
pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(), pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(),
kubeInformers.Core().V1().Secrets(), kubeInformers.Core().V1().Secrets(),
@ -231,8 +265,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList { for i := range actualIDPList {
actualIDP := actualIDPList[i].(*upstreamldap.Provider) actualIDP := actualIDPList[i].(*upstreamldap.Provider)
require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName()) require.Equal(t, tt.wantResultingCache[i], actualIDP)
// TODO more assertions
} }
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{}) actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().LDAPIdentityProviders(testNamespace).List(ctx, metav1.ListOptions{})

View File

@ -12,7 +12,7 @@ import (
"net" "net"
"strings" "strings"
ldap "github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"k8s.io/apiserver/pkg/authentication/authenticator" "k8s.io/apiserver/pkg/authentication/authenticator"
) )
@ -32,9 +32,18 @@ type Conn interface {
// Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn. // Our Conn type is subset of the ldap.Client interface, which is implemented by ldap.Conn.
var _ Conn = &ldap.Conn{} var _ Conn = &ldap.Conn{}
// LDAPDialerFunc is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP. // LDAPDialer is a factory of Conn, and the resulting Conn can then be used to interact with an upstream LDAP IDP.
type LDAPDialer interface {
Dial(ctx context.Context, hostAndPort string) (Conn, error)
}
// LDAPDialerFunc makes it easy to use a func as an LDAPDialer.
type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error) type LDAPDialerFunc func(ctx context.Context, hostAndPort string) (Conn, error)
func (f LDAPDialerFunc) Dial(ctx context.Context, hostAndPort string) (Conn, error) {
return f(ctx, hostAndPort)
}
// Provider includes all of the settings for connection and searching for users and groups in // Provider includes all of the settings for connection and searching for users and groups in
// the upstream LDAP IDP. It also provides methods for testing the connection and performing logins. // the upstream LDAP IDP. It also provides methods for testing the connection and performing logins.
type Provider struct { type Provider struct {
@ -57,8 +66,8 @@ type Provider struct {
// UserSearch contains information about how to search for users in the upstream LDAP IDP. // UserSearch contains information about how to search for users in the upstream LDAP IDP.
UserSearch *UserSearch UserSearch *UserSearch
// Dial exists to enable testing. When nil, will use a default appropriate for production use. // Dialer exists to enable testing. When nil, will use a default appropriate for production use.
Dial LDAPDialerFunc Dialer LDAPDialer
} }
// UserSearch contains information about how to search for users in the upstream LDAP IDP. // UserSearch contains information about how to search for users in the upstream LDAP IDP.
@ -83,13 +92,13 @@ func (p *Provider) dial(ctx context.Context) (Conn, error) {
if err != nil { if err != nil {
return nil, ldap.NewError(ldap.ErrorNetwork, err) return nil, ldap.NewError(ldap.ErrorNetwork, err)
} }
if p.Dial != nil { if p.Dialer != nil {
return p.Dial(ctx, hostAndPort) return p.Dialer.Dial(ctx, hostAndPort)
} }
return p.dialTLS(ctx, hostAndPort) return p.dialTLS(ctx, hostAndPort)
} }
// dialTLS is the default implementation of the Dial func, used when Dial is nil. // dialTLS is the default implementation of the Dialer, used when Dialer is nil.
// Unfortunately, the go-ldap library does not seem to support dialing with a context.Context, // Unfortunately, the go-ldap library does not seem to support dialing with a context.Context,
// so we implement it ourselves, heavily inspired by ldap.DialURL. // so we implement it ourselves, heavily inspired by ldap.DialURL.
func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) { func (p *Provider) dialTLS(ctx context.Context, hostAndPort string) (Conn, error) {

View File

@ -94,11 +94,11 @@ func TestAuthenticateUser(t *testing.T) {
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
dialWasAttempted := false dialWasAttempted := false
test.provider.Dial = func(ctx context.Context, hostAndPort string) (Conn, error) { test.provider.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) {
dialWasAttempted = true dialWasAttempted = true
require.Equal(t, test.provider.Host, hostAndPort) require.Equal(t, test.provider.Host, hostAndPort)
return conn, nil return conn, nil
} })
authResponse, authenticated, err := test.provider.AuthenticateUser(context.Background(), upstreamUsername, upstreamPassword) authResponse, authenticated, err := test.provider.AuthenticateUser(context.Background(), upstreamUsername, upstreamPassword)
require.True(t, dialWasAttempted, "AuthenticateUser was supposed to try to dial, but didn't") require.True(t, dialWasAttempted, "AuthenticateUser was supposed to try to dial, but didn't")
@ -181,7 +181,7 @@ func TestRealTLSDialing(t *testing.T) {
provider := &Provider{ provider := &Provider{
Host: test.host, Host: test.host,
CABundle: test.caBundle, CABundle: test.caBundle,
Dial: nil, // this test is for the default (production) dialer Dialer: nil, // this test is for the default (production) dialer
} }
conn, err := provider.dial(test.context) conn, err := provider.dial(test.context)
if conn != nil { if conn != nil {
@ -198,7 +198,7 @@ func TestRealTLSDialing(t *testing.T) {
// Can't test its methods here because we are not dialed to a real LDAP server. // Can't test its methods here because we are not dialed to a real LDAP server.
require.IsType(t, &ldap.Conn{}, conn) require.IsType(t, &ldap.Conn{}, conn)
// Indirectly checking that the Dial method constructed the ldap.Conn with isTLS set to true, // Indirectly checking that the Dialer method constructed the ldap.Conn with isTLS set to true,
// since this is always the correct behavior unless/until we want to support StartTLS. // since this is always the correct behavior unless/until we want to support StartTLS.
err := conn.(*ldap.Conn).StartTLS(&tls.Config{}) err := conn.(*ldap.Conn).StartTLS(&tls.Config{})
require.EqualError(t, err, `LDAP Result Code 200 "Network Error": ldap: already encrypted`) require.EqualError(t, err, `LDAP Result Code 200 "Network Error": ldap: already encrypted`)