diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 136fc023..ddee048b 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -715,7 +715,9 @@ func (p *Provider) groupSearchRequestedAttributes() []string { } func (p *Provider) userSearchFilter(username string) string { - safeUsername := p.escapeUsernameForSearchFilter(username) + // The username is end user input, so it should be escaped before being included in a search to prevent + // query injection. + safeUsername := p.escapeForSearchFilter(username) if len(p.c.UserSearch.Filter) == 0 { return fmt.Sprintf("(%s=%s)", p.c.UserSearch.UsernameAttribute, safeUsername) } @@ -723,10 +725,14 @@ func (p *Provider) userSearchFilter(username string) string { } func (p *Provider) groupSearchFilter(userDN string) string { + // The DN can contain characters that are considered special characters by LDAP searches, so it should be + // escaped before being included in the search filter to prevent bad search syntax. + // E.g. for the DN `CN=My User (Admin),OU=Users,OU=my,DC=my,DC=domain` we must escape the parens. + safeUserDN := p.escapeForSearchFilter(userDN) if len(p.c.GroupSearch.Filter) == 0 { - return fmt.Sprintf("(member=%s)", userDN) + return fmt.Sprintf("(member=%s)", safeUserDN) } - return interpolateSearchFilter(p.c.GroupSearch.Filter, userDN) + return interpolateSearchFilter(p.c.GroupSearch.Filter, safeUserDN) } func interpolateSearchFilter(filterFormat, valueToInterpolateIntoFilter string) string { @@ -737,9 +743,8 @@ func interpolateSearchFilter(filterFormat, valueToInterpolateIntoFilter string) return "(" + filter + ")" } -func (p *Provider) escapeUsernameForSearchFilter(username string) string { - // The username is end user input, so it should be escaped before being included in a search to prevent query injection. - return ldap.EscapeFilter(username) +func (p *Provider) escapeForSearchFilter(s string) string { + return ldap.EscapeFilter(s) } // Returns the (potentially) binary data of the attribute's value, base64 URL encoded. diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 311aa6ec..c8bdc395 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -479,7 +479,7 @@ func TestEndUserAuthentication(t *testing.T) { wantAuthResponse: expectedAuthResponse(nil), }, { - name: "when the username has special LDAP search filter characters then they must be properly escaped in the search filter, because the username is end-user input", + name: "when the username has special LDAP search filter characters then they must be properly escaped in the custom user search filter, because the username is end-user input", username: `a&b|c(d)e\f*g`, password: testUpstreamPassword, providerConfig: providerConfig(nil), @@ -497,6 +497,92 @@ func TestEndUserAuthentication(t *testing.T) { }, wantAuthResponse: expectedAuthResponse(nil), }, + { + name: "when the username has special LDAP search filter characters then they must be properly escaped in the default user search filter, because the username is end-user input", + username: `a&b|c(d)e\f*g`, + password: testUpstreamPassword, + providerConfig: providerConfig(func(p *ProviderConfig) { + p.UserSearch.Filter = "" + }), + searchMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(func(r *ldap.SearchRequest) { + r.Filter = fmt.Sprintf("(some-upstream-username-attribute=%s)", `a&b|c\28d\29e\5cf\2ag`) + })).Return(exampleUserSearchResult, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize). + Return(exampleGroupSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + bindEndUserMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Times(1) + }, + wantAuthResponse: expectedAuthResponse(nil), + }, + { + name: "when the user search result DN has special LDAP search filter characters then they must be properly escaped in the custom group search filter", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(nil), + searchMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)). + Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: `result DN with * \ special characters ()`, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testUserSearchResultUsernameAttributeValue}), + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testUserSearchResultUIDAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { + escapedDN := `result DN with \2a \5c special characters \28\29` + r.Filter = fmt.Sprintf("(some-group-filter=%s-and-more-filter=%s)", escapedDN, escapedDN) + }), expectedGroupSearchPageSize).Return(exampleGroupSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + bindEndUserMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(`result DN with * \ special characters ()`, testUpstreamPassword).Times(1) + }, + wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { + r.DN = `result DN with * \ special characters ()` + }), + }, + { + name: "when the user search result DN has special LDAP search filter characters then they must be properly escaped in the default group search filter", + username: testUpstreamUsername, + password: testUpstreamPassword, + providerConfig: providerConfig(func(p *ProviderConfig) { + p.GroupSearch.Filter = "" + }), + searchMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)). + Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: `result DN with * \ special characters ()`, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testUserSearchResultUsernameAttributeValue}), + ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testUserSearchResultUIDAttributeValue}), + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { + r.Filter = fmt.Sprintf("(member=%s)", `result DN with \2a \5c special characters \28\29`) + }), expectedGroupSearchPageSize).Return(exampleGroupSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + bindEndUserMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(`result DN with * \ special characters ()`, testUpstreamPassword).Times(1) + }, + wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { + r.DN = `result DN with * \ special characters ()` + }), + }, { name: "group names are sorted to make the result more stable/predictable", username: testUpstreamUsername,