diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 8e62f724..eca83707 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -14,10 +14,13 @@ import ( "github.com/go-ldap/ldap/v3" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/apiserver/pkg/authentication/user" ) const ( - ldapsScheme = "ldaps" + ldapsScheme = "ldaps" + distinguishedNameAttributeName = "dn" + userSearchFilterInterpolationLocationMarker = "{}" ) // Conn abstracts the upstream LDAP communication protocol (mostly for testing). @@ -158,25 +161,152 @@ func (p *Provider) GetURL() string { return fmt.Sprintf("%s://%s", ldapsScheme, p.Host) } -// TestConnection provides a method for testing the connection and bind settings by dialing and binding. -func (p *Provider) TestConnection(ctx context.Context) error { +// TestConnection provides a method for testing the connection and bind settings. It performs a dial and bind +// and returns any errors that we encountered. +func (p *Provider) TestConnection(ctx context.Context) (*authenticator.Response, error) { _, _ = p.dial(ctx) - // TODO bind using the bind credentials - // TODO close - // TODO return any dial or bind errors - return nil + // TODO implement me + return nil, nil +} + +// TestAuthenticateUser provides a method for testing all of the Provider settings in a kind of dry run of +// authentication. It runs the same logic as AuthenticateUser except it does not bind as that user, so it does not test +// their password. It returns the same authenticator.Response values and the same errors that a real call to +// AuthenticateUser with the correct password would return. +func (p *Provider) TestAuthenticateUser(ctx context.Context, testUsername string) (*authenticator.Response, error) { + // TODO implement me + return nil, nil } // Authenticate a user and return their mapped username, groups, and UID. Implements authenticators.UserAuthenticator. func (p *Provider) AuthenticateUser(ctx context.Context, username, password string) (*authenticator.Response, bool, error) { - _, _ = p.dial(ctx) - // TODO bind - // TODO user search - // TODO user bind - // TODO map username and uid attributes - // TODO group search - // TODO map group attributes - // TODO close - // TODO return any errors that were encountered along the way - return nil, false, nil + conn, err := p.dial(ctx) + if err != nil { + return nil, false, fmt.Errorf(`error dialing host "%s": %w`, p.Host, err) + } + defer conn.Close() + + err = conn.Bind(p.BindUsername, p.BindPassword) + if err != nil { + // TODO test this + return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.BindUsername, err) + } + + mappedUsername, mappedUID, err := p.searchAndBindUser(conn, username, password) + if err != nil { + return nil, false, err + } + + response := &authenticator.Response{ + User: &user.DefaultInfo{ + Name: mappedUsername, + UID: mappedUID, + Groups: []string{}, // Support for group search coming soon. + }, + } + return response, true, nil +} + +func (p *Provider) searchAndBindUser(conn Conn, username string, password string) (string, string, error) { + searchResult, err := conn.Search(p.userSearchRequest(username)) + if err != nil { + // TODO test this + return "", "", fmt.Errorf(`error searching for user "%s": %w`, username, err) + } + if len(searchResult.Entries) != 1 { + // TODO test this + return "", "", fmt.Errorf(`searching for user "%s" resulted in %d search results, but expected 1 result`, + username, len(searchResult.Entries), + ) + } + userEntry := searchResult.Entries[0] + if len(userEntry.DN) == 0 { + // TODO test this + return "", "", fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username) + } + + mappedUsername, err := p.getSearchResultAttributeValue(p.UserSearch.UsernameAttribute, userEntry, username) + if err != nil { + // TODO test this + return "", "", err + } + + mappedUID, err := p.getSearchResultAttributeValue(p.UserSearch.UIDAttribute, userEntry, username) + if err != nil { + // TODO test this + return "", "", err + } + + // Take care that any other LDAP commands after this bind will be run as this user instead of as the configured BindUsername! + err = conn.Bind(userEntry.DN, password) + if err != nil { + // TODO test this + return "", "", fmt.Errorf(`error binding for user "%s" using provided password against DN "%s": %w`, username, userEntry.DN, err) + } + + return mappedUsername, mappedUID, nil +} + +func (p *Provider) userSearchRequest(username string) *ldap.SearchRequest { + // See https://ldap.com/the-ldap-search-operation for general documentation of LDAP search options. + return &ldap.SearchRequest{ + BaseDN: p.UserSearch.Base, + Scope: ldap.ScopeWholeSubtree, + DerefAliases: ldap.DerefAlways, // TODO what's the best value here? + SizeLimit: 2, + TimeLimit: 90, + TypesOnly: false, + Filter: p.userSearchFilter(username), + Attributes: p.userSearchRequestedAttributes(), + Controls: nil, // this could be used to enable paging, but we're already limiting the result max size + } +} + +func (p *Provider) userSearchRequestedAttributes() []string { + attributes := []string{} + if p.UserSearch.UsernameAttribute != distinguishedNameAttributeName { + attributes = append(attributes, p.UserSearch.UsernameAttribute) + } + if p.UserSearch.UIDAttribute != distinguishedNameAttributeName { + attributes = append(attributes, p.UserSearch.UIDAttribute) + } + return attributes +} + +func (p *Provider) userSearchFilter(username string) string { + safeUsername := p.escapeUsernameForSearchFilter(username) + if len(p.UserSearch.Filter) == 0 { + return fmt.Sprintf("%s=%s", p.UserSearch.UsernameAttribute, safeUsername) + } + return strings.ReplaceAll(p.UserSearch.Filter, userSearchFilterInterpolationLocationMarker, safeUsername) +} + +func (p *Provider) escapeUsernameForSearchFilter(username string) string { + // The username is end-user input, so it should be escaped before being included in a search to prevent query injection. + return ldap.EscapeFilter(username) +} + +func (p *Provider) getSearchResultAttributeValue(attributeName string, fromUserEntry *ldap.Entry, username string) (string, error) { + if attributeName == distinguishedNameAttributeName { + return fromUserEntry.DN, nil + } + + attributeValues := fromUserEntry.GetAttributeValues(attributeName) + + if len(attributeValues) != 1 { + // TODO test this + return "", fmt.Errorf(`found %d values for attribute "%s" while searching for user "%s", but expected 1 result`, + len(attributeValues), attributeName, username, + ) + } + + attributeValue := attributeValues[0] + if len(attributeValue) == 0 { + // TODO test this + return "", fmt.Errorf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, + attributeName, username, + ) + } + + return attributeValue, nil } diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 358582f7..9b17c537 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -6,6 +6,7 @@ package upstreamldap import ( "context" "crypto/tls" + "errors" "fmt" "net" "net/http" @@ -22,92 +23,274 @@ import ( "go.pinniped.dev/internal/testutil" ) +const ( + testHost = "ldap.example.com:8443" + testBindUsername = "some-bind-username" + testBindPassword = "some-bind-password" + testUpstreamUsername = "some-upstream-username" + testUpstreamPassword = "some-upstream-password" + testUserSearchBase = "some-upstream-base-dn" + testUserSearchFilter = "some-filter={}-and-more-filter={}" + testUserSearchUsernameAttribute = "some-upstream-username-attribute" + testUserSearchUIDAttribute = "some-upstream-uid-attribute" + testSearchResultDNValue = "some-upstream-user-dn" + testSearchResultUsernameAttributeValue = "some-upstream-username-value" + testSearchResultUIDAttributeValue = "some-upstream-uid-value" +) + var ( - upstreamUsername = "some-upstream-username" - upstreamPassword = "some-upstream-password" - upstreamGroups = []string{"some-upstream-group-0", "some-upstream-group-1"} - upstreamUID = "some-upstream-uid" + testUserSearchFilterInterpolated = fmt.Sprintf("some-filter=%s-and-more-filter=%s", testUpstreamUsername, testUpstreamUsername) ) func TestAuthenticateUser(t *testing.T) { - // Please the linter... - _ = upstreamGroups - _ = upstreamUID - t.Skip("TODO: make me pass!") + provider := func(editFunc func(p *Provider)) *Provider { + provider := &Provider{ + Host: testHost, + BindUsername: testBindUsername, + BindPassword: testBindPassword, + UserSearch: &UserSearch{ + Base: testUserSearchBase, + Filter: testUserSearchFilter, + UsernameAttribute: testUserSearchUsernameAttribute, + UIDAttribute: testUserSearchUIDAttribute, + }, + } + if editFunc != nil { + editFunc(provider) + } + return provider + } + + expectedSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest { + request := &ldap.SearchRequest{ + BaseDN: testUserSearchBase, + Scope: ldap.ScopeWholeSubtree, + DerefAliases: ldap.DerefAlways, + SizeLimit: 2, + TimeLimit: 90, + TypesOnly: false, + Filter: testUserSearchFilterInterpolated, + Attributes: []string{testUserSearchUsernameAttribute, testUserSearchUIDAttribute}, + Controls: nil, + } + if editFunc != nil { + editFunc(request) + } + return request + } tests := []struct { - name string - provider *Provider - wantError string - wantUnauthenticated bool - wantAuthResponse *authenticator.Response + name string + username string + password string + provider *Provider + setupMocks func(conn *mockldapconn.MockConn) + dialError error + wantError string + wantAuthResponse *authenticator.Response }{ { - name: "happy path", - provider: &Provider{ - Host: "ldap.example.com:8443", - BindUsername: upstreamUsername, - BindPassword: upstreamPassword, - UserSearch: &UserSearch{ - Base: "some-upstream-base-dn", - Filter: "some-filter", - UsernameAttribute: "some-upstream-username-attribute", - UIDAttribute: "some-upstream-uid-attribute", - }, + name: "happy path", + username: testUpstreamUsername, + password: testUpstreamPassword, + provider: provider(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedSearch(nil)).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testSearchResultUsernameAttributeValue}), + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testSearchResultUIDAttributeValue}), + }, + }, + }, + Referrals: []string{}, // note that we are not following referrals at this time + Controls: []ldap.Control{}, // TODO are there any response controls that we need to be able to handle? + }, nil).Times(1) + conn.EXPECT().Bind(testSearchResultDNValue, testUpstreamPassword).Times(1) + conn.EXPECT().Close().Times(1) }, wantAuthResponse: &authenticator.Response{ User: &user.DefaultInfo{ - Name: upstreamUsername, - Groups: upstreamGroups, - UID: upstreamUID, + Name: testSearchResultUsernameAttributeValue, + Groups: []string{}, // We don't support group search yet. Coming soon! + UID: testSearchResultUIDAttributeValue, }, }, }, + { + name: "when the UsernameAttribute is dn", + username: testUpstreamUsername, + password: testUpstreamPassword, + provider: provider(func(p *Provider) { + p.UserSearch.UsernameAttribute = "dn" + }), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { + r.Attributes = []string{testUserSearchUIDAttribute} + })).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testSearchResultUIDAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Bind(testSearchResultDNValue, testUpstreamPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantAuthResponse: &authenticator.Response{ + User: &user.DefaultInfo{ + Name: testSearchResultDNValue, + Groups: []string{}, // We don't support group search yet. Coming soon! + UID: testSearchResultUIDAttributeValue, + }, + }, + }, + { + name: "when the UIDAttribute is dn", + username: testUpstreamUsername, + password: testUpstreamPassword, + provider: provider(func(p *Provider) { + p.UserSearch.UIDAttribute = "dn" + }), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { + r.Attributes = []string{testUserSearchUsernameAttribute} + })).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testSearchResultUsernameAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Bind(testSearchResultDNValue, testUpstreamPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantAuthResponse: &authenticator.Response{ + User: &user.DefaultInfo{ + Name: testSearchResultUsernameAttributeValue, + Groups: []string{}, // We don't support group search yet. Coming soon! + UID: testSearchResultDNValue, + }, + }, + }, + { + name: "when Filter is blank it derives a search filter from the UsernameAttribute", + username: testUpstreamUsername, + password: testUpstreamPassword, + provider: provider(func(p *Provider) { + p.UserSearch.Filter = "" + }), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { + r.Filter = testUserSearchUsernameAttribute + "=" + testUpstreamUsername + })).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testSearchResultUsernameAttributeValue}), + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testSearchResultUIDAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Bind(testSearchResultDNValue, testUpstreamPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantAuthResponse: &authenticator.Response{ + User: &user.DefaultInfo{ + Name: testSearchResultUsernameAttributeValue, + Groups: []string{}, // We don't support group search yet. Coming soon! + UID: testSearchResultUIDAttributeValue, + }, + }, + }, + { + name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter", + username: `a&b|c(d)e\f*g`, + password: testUpstreamPassword, + provider: provider(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedSearch(func(r *ldap.SearchRequest) { + r.Filter = fmt.Sprintf("some-filter=%s-and-more-filter=%s", `a&b|c\28d\29e\5cf\2ag`, `a&b|c\28d\29e\5cf\2ag`) + })).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testSearchResultUsernameAttributeValue}), + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testSearchResultUIDAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Bind(testSearchResultDNValue, testUpstreamPassword).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantAuthResponse: &authenticator.Response{ + User: &user.DefaultInfo{ + Name: testSearchResultUsernameAttributeValue, + Groups: []string{}, // We don't support group search yet. Coming soon! + UID: testSearchResultUIDAttributeValue, + }, + }, + }, + // TODO are LDAP attribute names case sensitive? do we need any special handling for case? + { + name: "when dial fails", + provider: provider(nil), + dialError: errors.New("some dial error"), + wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), + }, } + for _, test := range tests { - test := test - t.Run(test.name, func(t *testing.T) { + tt := test + t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) + conn := mockldapconn.NewMockConn(ctrl) - conn.EXPECT().Bind(test.provider.BindUsername, test.provider.BindPassword).Times(1) - conn.EXPECT().Search(&ldap.SearchRequest{ - BaseDN: test.provider.UserSearch.Base, - Scope: 99, // TODO: what should this be? - DerefAliases: 99, // TODO: what should this be? - SizeLimit: 99, - TimeLimit: 99, // TODO: what should this be? - TypesOnly: true, // TODO: what should this be? - Filter: test.provider.UserSearch.Filter, - Attributes: []string{}, // TODO: what should this be? - Controls: []ldap.Control{}, // TODO: what should this be? - }).Return(&ldap.SearchResult{ - Entries: []*ldap.Entry{ - { - DN: "", // TODO: what should this be? - Attributes: []*ldap.EntryAttribute{}, // TODO: what should this be? - }, - }, - Referrals: []string{}, // TODO: what should this be? - Controls: []ldap.Control{}, // TODO: what should this be? - }, nil).Times(1) - conn.EXPECT().Close().Times(1) + if tt.setupMocks != nil { + tt.setupMocks(conn) + } dialWasAttempted := false - test.provider.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { + tt.provider.Dialer = LDAPDialerFunc(func(ctx context.Context, hostAndPort string) (Conn, error) { dialWasAttempted = true - require.Equal(t, test.provider.Host, hostAndPort) + require.Equal(t, tt.provider.Host, hostAndPort) + if tt.dialError != nil { + return nil, tt.dialError + } return conn, nil }) - authResponse, authenticated, err := test.provider.AuthenticateUser(context.Background(), upstreamUsername, upstreamPassword) + authResponse, authenticated, err := tt.provider.AuthenticateUser(context.Background(), tt.username, tt.password) + require.True(t, dialWasAttempted, "AuthenticateUser was supposed to try to dial, but didn't") - if test.wantError != "" { - require.EqualError(t, err, test.wantError) - return + + if tt.wantError != "" { + require.EqualError(t, err, tt.wantError) + require.False(t, authenticated) + require.Nil(t, authResponse) + } else { + require.NoError(t, err) + require.True(t, authenticated) + require.Equal(t, tt.wantAuthResponse, authResponse) } - require.Equal(t, !test.wantUnauthenticated, authenticated) - require.Equal(t, test.wantAuthResponse, authResponse) }) } }