diff --git a/internal/fositestorage/authorizationcode/authorizationcode.go b/internal/fositestorage/authorizationcode/authorizationcode.go index b259e406..3418f672 100644 --- a/internal/fositestorage/authorizationcode/authorizationcode.go +++ b/internal/fositestorage/authorizationcode/authorizationcode.go @@ -328,14 +328,22 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ "providerType": "闣ʬ橳(ý綃ʃʚƟ覣k眐4Ĉt", "oidc": { "upstreamRefreshToken": "嵽痊w©Ź榨Q|ôɵt毇妬" + }, + "ldap": { + "userDN": "6鉢緋uƴŤȱʀļÂ?墖\u003cƬb獭潜Ʃ饾" + }, + "activedirectory": { + "userDN": "|鬌R蜚蠣麹概÷驣7Ʀ澉1æɽ誮rʨ鷞" } } }, "requestedAudience": [ - "6鉢緋uƴŤȱʀļÂ?墖\u003cƬb獭潜Ʃ饾" + "ŚB碠k9" ], "grantedAudience": [ - "|鬌R蜚蠣麹概÷驣7Ʀ澉1æɽ誮rʨ鷞" + "ʘ赱", + "ď逳鞪?3)藵睋邔\u0026Ű惫蜀Ģ¡圔", + "墀jMʥ" ] }, "version": "2" diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 4c457faf..d22f9ad7 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -112,6 +112,10 @@ func handleAuthRequestForLDAPUpstream( subject := downstreamSubjectFromUpstreamLDAP(ldapUpstream, authenticateResponse) username = authenticateResponse.User.GetName() groups := authenticateResponse.User.GetGroups() + dn := userDNFromAuthenticatedResponse(authenticateResponse) + if dn == "" { + return httperr.New(http.StatusInternalServerError, "unexpected error during upstream authentication") + } customSessionData := &psession.CustomSessionData{ ProviderUID: ldapUpstream.GetResourceUID(), @@ -119,6 +123,17 @@ func handleAuthRequestForLDAPUpstream( ProviderType: idpType, } + if idpType == psession.ProviderTypeLDAP { + customSessionData.LDAP = &psession.LDAPSessionData{ + UserDN: dn, + } + } + if idpType == psession.ProviderTypeActiveDirectory { + customSessionData.ActiveDirectory = &psession.ActiveDirectorySessionData{ + UserDN: dn, + } + } + return makeDownstreamSessionAndReturnAuthcodeRedirect(r, w, oauthHelper, authorizeRequester, subject, username, groups, customSessionData) } @@ -477,3 +492,16 @@ func downstreamSubjectFromUpstreamLDAP(ldapUpstream provider.UpstreamLDAPIdentit ldapURL.RawQuery = q.Encode() return ldapURL.String() } + +func userDNFromAuthenticatedResponse(authenticatedResponse *authenticator.Response) string { + // These errors shouldn't happen, but do some error checking anyway so it doesn't panic + extra := authenticatedResponse.User.GetExtra() + if len(extra) == 0 { + return "" + } + dnSlice := extra["userDN"] + if len(dnSlice) != 1 { + return "" + } + return dnSlice[0] +} diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 7f61c103..39264037 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -267,6 +267,7 @@ func TestAuthorizationEndpoint(t *testing.T) { happyLDAPUsernameFromAuthenticator := "some-mapped-ldap-username" happyLDAPPassword := "some-ldap-password" //nolint:gosec happyLDAPUID := "some-ldap-uid" + happyLDAPUserDN := "cn=foo,dn=bar" happyLDAPGroups := []string{"group1", "group2", "group3"} parsedUpstreamLDAPURL, err := url.Parse(upstreamLDAPURL) @@ -282,6 +283,7 @@ func TestAuthorizationEndpoint(t *testing.T) { Name: happyLDAPUsernameFromAuthenticator, UID: happyLDAPUID, Groups: happyLDAPGroups, + Extra: map[string][]string{"userDN": {happyLDAPUserDN}}, }, }, true, nil } @@ -438,6 +440,10 @@ func TestAuthorizationEndpoint(t *testing.T) { ProviderName: activeDirectoryUpstreamName, ProviderType: psession.ProviderTypeActiveDirectory, OIDC: nil, + LDAP: nil, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: happyLDAPUserDN, + }, } expectedHappyLDAPUpstreamCustomSession := &psession.CustomSessionData{ @@ -445,6 +451,10 @@ func TestAuthorizationEndpoint(t *testing.T) { ProviderName: ldapUpstreamName, ProviderType: psession.ProviderTypeLDAP, OIDC: nil, + LDAP: &psession.LDAPSessionData{ + UserDN: happyLDAPUserDN, + }, + ActiveDirectory: nil, } expectedHappyOIDCPasswordGrantCustomSession := &psession.CustomSessionData{ diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 88710f00..c5f2b7b1 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -88,6 +88,9 @@ type UpstreamLDAPIdentityProviderI interface { // UserAuthenticator adds an interface method for performing user authentication against the upstream LDAP provider. authenticators.UserAuthenticator + + // PerformRefresh performs a refresh against the upstream LDAP identity provider + PerformRefresh(ctx context.Context, userDN string) error } type DynamicUpstreamIDPProvider interface { diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 30956524..dcec411b 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -89,14 +89,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, case psession.ProviderTypeOIDC: return upstreamOIDCRefresh(ctx, customSessionData, providerCache) case psession.ProviderTypeLDAP: - // upstream refresh not yet implemented for LDAP, so do nothing + return upstreamLDAPRefresh(ctx, customSessionData, providerCache) case psession.ProviderTypeActiveDirectory: - // upstream refresh not yet implemented for AD, so do nothing + return upstreamLDAPRefresh(ctx, customSessionData, providerCache) default: return errorsx.WithStack(errMissingUpstreamSessionInternalError) } - - return nil } func upstreamOIDCRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { @@ -165,3 +163,54 @@ func findOIDCProviderByNameAndValidateUID( return nil, errorsx.WithStack(errUpstreamRefreshError. WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) } + +func upstreamLDAPRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { + plog.Warning("refreshing upstream") + // if you have neither a valid ldap session config nor a valid active directory session config + if (s.LDAP == nil || s.LDAP.UserDN == "") && (s.ActiveDirectory == nil || s.ActiveDirectory.UserDN == "") { + return errorsx.WithStack(errMissingUpstreamSessionInternalError) + } + plog.Warning("going to find provider", "provider", s.ProviderName) + + // get ldap/ad provider out of cache + p, dn, _ := findLDAPProviderByNameAndValidateUID(s, providerCache) + // TODO error checking + // run PerformRefresh + err := p.PerformRefresh(ctx, dn) + if err != nil { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh failed using provider %q of type %q.", + s.ProviderName, s.ProviderType).WithWrap(err)) + } + + return nil +} + +func findLDAPProviderByNameAndValidateUID( + s *psession.CustomSessionData, + providerCache oidc.UpstreamIdentityProvidersLister, +) (provider.UpstreamLDAPIdentityProviderI, string, error) { + var providers []provider.UpstreamLDAPIdentityProviderI + var dn string + if s.ProviderType == psession.ProviderTypeLDAP { + providers = providerCache.GetLDAPIdentityProviders() + dn = s.LDAP.UserDN + } else if s.ProviderType == psession.ProviderTypeActiveDirectory { + providers = providerCache.GetActiveDirectoryIdentityProviders() + dn = s.ActiveDirectory.UserDN + } + + for _, p := range providers { + if p.GetName() == s.ProviderName { + if p.GetResourceUID() != s.ProviderUID { + return nil, "", errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Provider %q of type %q from upstream session data has changed its resource UID since authentication.", + s.ProviderName, s.ProviderType)) + } + return p, dn, nil + } + } + + return nil, "", errorsx.WithStack(errUpstreamRefreshError. + WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) +} diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index 3b6be1ff..c819847e 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -232,7 +232,7 @@ type tokenEndpointResponseExpectedValues struct { wantErrorResponseBody string wantRequestedScopes []string wantGrantedScopes []string - wantUpstreamOIDCRefreshCall *expectedUpstreamRefresh + wantUpstreamRefreshCall *expectedUpstreamRefresh wantUpstreamOIDCValidateTokenCall *expectedUpstreamValidateTokens wantCustomSessionDataStored *psession.CustomSessionData } @@ -879,8 +879,20 @@ func TestRefreshGrant(t *testing.T) { oidcUpstreamInitialRefreshToken = "initial-upstream-refresh-token" oidcUpstreamRefreshedIDToken = "fake-refreshed-id-token" oidcUpstreamRefreshedRefreshToken = "fake-refreshed-refresh-token" + + ldapUpstreamName = "some-ldap-idp" + ldapUpstreamResourceUID = "ldap-resource-uid" + ldapUpstreamType = "ldap" + ldapUpstreamDN = "some-ldap-user-dn" + + activeDirectoryUpstreamName = "some-ad-idp" + activeDirectoryUpstreamResourceUID = "ad-resource-uid" + activeDirectoryUpstreamType = "activedirectory" + activeDirectoryUpstreamDN = "some-ad-user-dn" ) + ldapUpstreamURL, _ := url.Parse("some-url") + // The below values are funcs so every test can have its own copy of the objects, to avoid data races // in these parallel tests. @@ -907,7 +919,7 @@ func TestRefreshGrant(t *testing.T) { return sessionData } - happyUpstreamRefreshCall := func() *expectedUpstreamRefresh { + happyOIDCUpstreamRefreshCall := func() *expectedUpstreamRefresh { return &expectedUpstreamRefresh{ performedByUpstreamName: oidcUpstreamName, args: &oidctestutil.PerformRefreshArgs{ @@ -917,6 +929,26 @@ func TestRefreshGrant(t *testing.T) { } } + happyLDAPUpstreamRefreshCall := func() *expectedUpstreamRefresh { + return &expectedUpstreamRefresh{ + performedByUpstreamName: ldapUpstreamName, + args: &oidctestutil.PerformRefreshArgs{ + Ctx: nil, + DN: ldapUpstreamDN, + }, + } + } + + happyActiveDirectoryUpstreamRefreshCall := func() *expectedUpstreamRefresh { + return &expectedUpstreamRefresh{ + performedByUpstreamName: activeDirectoryUpstreamName, + args: &oidctestutil.PerformRefreshArgs{ + Ctx: nil, + DN: activeDirectoryUpstreamDN, + }, + } + } + happyUpstreamValidateTokenCall := func(expectedTokens *oauth2.Token) *expectedUpstreamValidateTokens { return &expectedUpstreamValidateTokens{ performedByUpstreamName: oidcUpstreamName, @@ -944,7 +976,7 @@ func TestRefreshGrant(t *testing.T) { // same as the same values as the authcode exchange case. want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) // Should always try to perform an upstream refresh. - want.wantUpstreamOIDCRefreshCall = happyUpstreamRefreshCall() + want.wantUpstreamRefreshCall = happyOIDCUpstreamRefreshCall() // Should only try to ValidateToken when there was an id token returned by the upstream refresh. if expectToValidateToken != nil { want.wantUpstreamOIDCValidateTokenCall = happyUpstreamValidateTokenCall(expectToValidateToken) @@ -952,6 +984,18 @@ func TestRefreshGrant(t *testing.T) { return want } + happyRefreshTokenResponseForLDAP := func(wantCustomSessionDataStored *psession.CustomSessionData) tokenEndpointResponseExpectedValues { + want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) + want.wantUpstreamRefreshCall = happyLDAPUpstreamRefreshCall() + return want + } + + happyRefreshTokenResponseForActiveDirectory := func(wantCustomSessionDataStored *psession.CustomSessionData) tokenEndpointResponseExpectedValues { + want := happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(wantCustomSessionDataStored) + want.wantUpstreamRefreshCall = happyActiveDirectoryUpstreamRefreshCall() + return want + } + refreshedUpstreamTokensWithRefreshTokenWithoutIDToken := func() *oauth2.Token { return &oauth2.Token{ AccessToken: "fake-refreshed-access-token", @@ -973,10 +1017,11 @@ func TestRefreshGrant(t *testing.T) { } tests := []struct { - name string - idps *oidctestutil.UpstreamIDPListerBuilder - authcodeExchange authcodeExchangeInputs - refreshRequest refreshRequestInputs + name string + idps *oidctestutil.UpstreamIDPListerBuilder + authcodeExchange authcodeExchangeInputs + authEndpointInitialSessionData *psession.CustomSessionData + refreshRequest refreshRequestInputs }{ { name: "happy path refresh grant with openid scope granted (id token returned)", @@ -1015,7 +1060,7 @@ func TestRefreshGrant(t *testing.T) { wantSuccessBodyFields: []string{"refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"offline_access"}, wantGrantedScopes: []string{"offline_access"}, - wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), }, @@ -1096,7 +1141,7 @@ func TestRefreshGrant(t *testing.T) { wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantRequestedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, wantGrantedScopes: []string{"openid", "offline_access", "pinniped:request-audience"}, - wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantCustomSessionDataStored: upstreamOIDCCustomSessionDataWithNewRefreshToken(oidcUpstreamRefreshedRefreshToken), }, @@ -1449,8 +1494,8 @@ func TestRefreshGrant(t *testing.T) { }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), - wantStatus: http.StatusUnauthorized, + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), + wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { "error": "error", @@ -1474,7 +1519,7 @@ func TestRefreshGrant(t *testing.T) { }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantUpstreamOIDCRefreshCall: happyUpstreamRefreshCall(), + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` @@ -1486,6 +1531,322 @@ func TestRefreshGrant(t *testing.T) { }, }, }, + { + name: "upstream ldap refresh happy path", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: ldapUpstreamName, + ResourceUID: ldapUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForLDAP( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + ), + }, + }, + { + name: "upstream active directory refresh happy path", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithActiveDirectory(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: activeDirectoryUpstreamName, + ResourceUID: activeDirectoryUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: happyRefreshTokenResponseForActiveDirectory( + &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + ), + }, + }, + { + name: "upstream ldap refresh when the LDAP session data is nil", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: ldapUpstreamName, + ResourceUID: ldapUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: nil, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: nil, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `), + }, + }, + }, + { + name: "upstream active directory refresh when the ad session data is nil", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: activeDirectoryUpstreamName, + ResourceUID: activeDirectoryUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: nil, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + LDAP: nil, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `), + }, + }, + }, + { + name: "upstream ldap refresh when the LDAP session data does not contain dn", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: ldapUpstreamName, + ResourceUID: ldapUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: "", + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: "", + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `), + }, + }, + }, + { + name: "upstream active directory refresh when the active directory session data does not contain dn", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: ldapUpstreamName, + ResourceUID: ldapUpstreamResourceUID, + URL: ldapUpstreamURL, + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: "", + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: "", + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusInternalServerError, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "There was an internal server error. Required upstream data not found in session." + } + `), + }, + }, + }, + { + name: "upstream ldap refresh returns an error", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: ldapUpstreamName, + ResourceUID: ldapUpstreamResourceUID, + URL: ldapUpstreamURL, + PerformRefreshErr: errors.New("Some error performing upstream refresh"), + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh failed using provider 'some-ldap-idp' of type 'ldap'." + } + `), + }, + }, + }, + { + name: "upstream active directory refresh returns an error", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithActiveDirectory(&oidctestutil.TestUpstreamLDAPIdentityProvider{ + Name: activeDirectoryUpstreamName, + ResourceUID: activeDirectoryUpstreamResourceUID, + URL: ldapUpstreamURL, + PerformRefreshErr: errors.New("Some error performing upstream refresh"), + }), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh failed using provider 'some-ad-idp' of type 'activedirectory'." + } + `), + }, + }, + }, } for _, test := range tests { test := test @@ -1493,6 +1854,8 @@ func TestRefreshGrant(t *testing.T) { t.Parallel() // First exchange the authcode for tokens, including a refresh token. + // its actually fine to use this function even when simulating ldap (which uses a different flow) because it's + // just populating a secret in storage. subject, rsp, authCode, jwtSigningKey, secrets, oauthStore := exchangeAuthcodeForTokens(t, test.authcodeExchange, test.idps.Build()) var parsedAuthcodeExchangeResponseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) @@ -1525,11 +1888,11 @@ func TestRefreshGrant(t *testing.T) { t.Logf("second response body: %q", refreshResponse.Body.String()) // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. - if test.refreshRequest.want.wantUpstreamOIDCRefreshCall != nil { - test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args.Ctx = reqContext + if test.refreshRequest.want.wantUpstreamRefreshCall != nil { + test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext test.idps.RequireExactlyOneCallToPerformRefresh(t, - test.refreshRequest.want.wantUpstreamOIDCRefreshCall.performedByUpstreamName, - test.refreshRequest.want.wantUpstreamOIDCRefreshCall.args, + test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName, + test.refreshRequest.want.wantUpstreamRefreshCall.args, ) } else { test.idps.RequireExactlyZeroCallsToPerformRefresh(t) diff --git a/internal/psession/pinniped_session.go b/internal/psession/pinniped_session.go index 72ea3bdb..8009f91d 100644 --- a/internal/psession/pinniped_session.go +++ b/internal/psession/pinniped_session.go @@ -45,6 +45,10 @@ type CustomSessionData struct { // Only used when ProviderType == "oidc". OIDC *OIDCSessionData `json:"oidc,omitempty"` + + LDAP *LDAPSessionData `json:"ldap,omitempty"` + + ActiveDirectory *ActiveDirectorySessionData `json:"activedirectory,omitempty"` } type ProviderType string @@ -60,6 +64,16 @@ type OIDCSessionData struct { UpstreamRefreshToken string `json:"upstreamRefreshToken"` } +// LDAPSessionData is the additional data needed by Pinniped when the upstream IDP is an LDAP provider. +type LDAPSessionData struct { + UserDN string `json:"userDN"` +} + +// ActiveDirectorySessionData is the additional data needed by Pinniped when the upstream IDP is an Active Directory provider. +type ActiveDirectorySessionData struct { + UserDN string `json:"userDN"` +} + // NewPinnipedSession returns a new empty session. func NewPinnipedSession() *PinnipedSession { return &PinnipedSession{ diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 90361ddd..84d160f1 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -63,6 +63,7 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct { type PerformRefreshArgs struct { Ctx context.Context RefreshToken string + DN string } // ValidateTokenArgs is used to spy on calls to @@ -74,10 +75,13 @@ type ValidateTokenArgs struct { } type TestUpstreamLDAPIdentityProvider struct { - Name string - ResourceUID types.UID - URL *url.URL - AuthenticateFunc func(ctx context.Context, username, password string) (*authenticator.Response, bool, error) + Name string + ResourceUID types.UID + URL *url.URL + AuthenticateFunc func(ctx context.Context, username, password string) (*authenticator.Response, bool, error) + performRefreshCallCount int + performRefreshArgs []*PerformRefreshArgs + PerformRefreshErr error } var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{} @@ -98,6 +102,32 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL { return u.URL } +func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, userDN string) error { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + u.performRefreshCallCount++ + u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ + Ctx: ctx, + DN: userDN, + }) + if u.PerformRefreshErr != nil { + return u.PerformRefreshErr + } + return nil +} + +func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int { + return u.performRefreshCallCount +} + +func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs { + if u.performRefreshArgs == nil { + u.performRefreshArgs = make([]*PerformRefreshArgs, 0) + } + return u.performRefreshArgs[call] +} + type TestUpstreamOIDCIdentityProvider struct { Name string ClientID string @@ -390,31 +420,54 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( t.Helper() var actualArgs *PerformRefreshArgs var actualNameOfUpstreamWhichMadeCall string - actualCallCountAcrossAllOIDCUpstreams := 0 + actualCallCountAcrossAllUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount - actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + actualCallCountAcrossAllUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.performRefreshArgs[0] } } - require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, - "should have been exactly one call to PerformRefresh() by all OIDC upstreams", + for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { + callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount + actualCallCountAcrossAllUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name + actualArgs = upstreamLDAP.performRefreshArgs[0] + } + } + for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders { + callCountOnThisUpstream := upstreamAD.performRefreshCallCount + actualCallCountAcrossAllUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamAD.Name + actualArgs = upstreamAD.performRefreshArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllUpstreams, + "should have been exactly one call to PerformRefresh() by all upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, - "PerformRefresh() was called on the wrong OIDC upstream", + "PerformRefresh() was called on the wrong upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { t.Helper() - actualCallCountAcrossAllOIDCUpstreams := 0 + actualCallCountAcrossAllUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { - actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.performRefreshCallCount + actualCallCountAcrossAllUpstreams += upstreamOIDC.performRefreshCallCount } - require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { + actualCallCountAcrossAllUpstreams += upstreamLDAP.performRefreshCallCount + } + for _, upstreamActiveDirectory := range b.upstreamActiveDirectoryIdentityProviders { + actualCallCountAcrossAllUpstreams += upstreamActiveDirectory.performRefreshCallCount + } + + require.Equal(t, 0, actualCallCountAcrossAllUpstreams, "expected exactly zero calls to PerformRefresh()", ) } diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 1baab58b..b0b05e07 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -169,6 +169,43 @@ func (p *Provider) GetConfig() ProviderConfig { return p.c } +func (p *Provider) PerformRefresh(ctx context.Context, userDN string) error { + t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) + defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches + search := p.refreshUserSearchRequest(userDN) + + conn, err := p.dial(ctx) + if err != nil { + p.traceAuthFailure(t, err) + return fmt.Errorf(`error dialing host "%s": %w`, p.c.Host, err) + } + defer conn.Close() + + err = conn.Bind(p.c.BindUsername, p.c.BindPassword) + if err != nil { + p.traceAuthFailure(t, err) + return fmt.Errorf(`error binding as "%s" before user search: %w`, p.c.BindUsername, err) + } + + searchResult, err := conn.Search(search) + + if err != nil { + return fmt.Errorf(`error searching for user "%s": %w`, userDN, err) + } + + // if any more or less than one entry, error. + // we don't need to worry about logging this because we know it's a dn. + if len(searchResult.Entries) != 1 { + return fmt.Errorf(`searching for user "%s" resulted in %d search results, but expected 1 result`, + userDN, len(searchResult.Entries), + ) + } + + // do nothing. if we got exactly one search result back then that means the user + // still exists. + return nil +} + func (p *Provider) dial(ctx context.Context) (Conn, error) { tlsAddr, err := endpointaddr.Parse(p.c.Host, defaultLDAPSPort) if err != nil { @@ -355,7 +392,7 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi return nil, false, fmt.Errorf(`error binding as "%s" before user search: %w`, p.c.BindUsername, err) } - mappedUsername, mappedUID, mappedGroupNames, err := p.searchAndBindUser(conn, username, bindFunc) + mappedUsername, mappedUID, mappedGroupNames, userDN, err := p.searchAndBindUser(conn, username, bindFunc) if err != nil { p.traceAuthFailure(t, err) return nil, false, err @@ -371,6 +408,7 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi Name: mappedUsername, UID: mappedUID, Groups: mappedGroupNames, + Extra: map[string][]string{"userDN": {userDN}}, }, } p.traceAuthSuccess(t) @@ -454,7 +492,7 @@ func (p *Provider) SearchForDefaultNamingContext(ctx context.Context) (string, e return searchBase, nil } -func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(conn Conn, foundUserDN string) error) (string, string, []string, error) { +func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(conn Conn, foundUserDN string) error) (string, string, []string, string, error) { searchResult, err := conn.Search(p.userSearchRequest(username)) if err != nil { plog.All(`error searching for user`, @@ -462,7 +500,7 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c "username", username, "err", err, ) - return "", "", nil, fmt.Errorf(`error searching for user: %w`, err) + return "", "", nil, "", fmt.Errorf(`error searching for user: %w`, err) } if len(searchResult.Entries) == 0 { if plog.Enabled(plog.LevelAll) { @@ -473,38 +511,38 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c } else { plog.Debug("error finding user: user not found (cowardly avoiding printing username because log level is not 'all')", "upstreamName", p.GetName()) } - return "", "", nil, nil + return "", "", nil, "", nil } // At this point, we have matched at least one entry, so we can be confident that the username is not actually // someone's password mistakenly entered into the username field, so we can log it without concern. if len(searchResult.Entries) > 1 { - return "", "", nil, fmt.Errorf(`searching for user "%s" resulted in %d search results, but expected 1 result`, + return "", "", nil, "", fmt.Errorf(`searching for user "%s" resulted in %d search results, but expected 1 result`, username, len(searchResult.Entries), ) } userEntry := searchResult.Entries[0] if len(userEntry.DN) == 0 { - return "", "", nil, fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username) + return "", "", nil, "", fmt.Errorf(`searching for user "%s" resulted in search result without DN`, username) } mappedUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, username) if err != nil { - return "", "", nil, err + return "", "", nil, "", err } // We would like to support binary typed attributes for UIDs, so always read them as binary and encode them, // even when the attribute may not be binary. mappedUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, username) if err != nil { - return "", "", nil, err + return "", "", nil, "", err } mappedGroupNames := []string{} if len(p.c.GroupSearch.Base) > 0 { mappedGroupNames, err = p.searchGroupsForUserDN(conn, userEntry.DN) if err != nil { - return "", "", nil, err + return "", "", nil, "", err } } sort.Strings(mappedGroupNames) @@ -516,12 +554,12 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c err, "upstreamName", p.GetName(), "username", username, "dn", userEntry.DN) ldapErr := &ldap.Error{} if errors.As(err, &ldapErr) && ldapErr.ResultCode == ldap.LDAPResultInvalidCredentials { - return "", "", nil, nil + return "", "", nil, "", nil } - return "", "", nil, fmt.Errorf(`error binding for user "%s" using provided password against DN "%s": %w`, username, userEntry.DN, err) + return "", "", nil, "", fmt.Errorf(`error binding for user "%s" using provided password against DN "%s": %w`, username, userEntry.DN, err) } - return mappedUsername, mappedUID, mappedGroupNames, nil + return mappedUsername, mappedUID, mappedGroupNames, userEntry.DN, nil } func (p *Provider) defaultNamingContextRequest() *ldap.SearchRequest { @@ -568,6 +606,21 @@ func (p *Provider) groupSearchRequest(userDN string) *ldap.SearchRequest { } } +func (p *Provider) refreshUserSearchRequest(dn string) *ldap.SearchRequest { + // See https://ldap.com/the-ldap-search-operation for general documentation of LDAP search options. + return &ldap.SearchRequest{ + BaseDN: dn, + Scope: ldap.ScopeBaseObject, + DerefAliases: ldap.NeverDerefAliases, + SizeLimit: 2, + TimeLimit: 90, + TypesOnly: false, + Filter: "(objectClass=*)", // we already have the dn, so the filter doesn't matter + Attributes: []string{}, // TODO this will need to include some other AD attributes + Controls: nil, // this could be used to enable paging, but we're already limiting the result max size + } +} + func (p *Provider) userSearchRequestedAttributes() []string { attributes := []string{} if p.c.UserSearch.UsernameAttribute != distinguishedNameAttributeName { diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 0cb1f355..b7f953be 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -156,6 +156,7 @@ func TestEndUserAuthentication(t *testing.T) { Name: testUserSearchResultUsernameAttributeValue, UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)), Groups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2}, + Extra: map[string][]string{"userDN": {testUserSearchResultDNValue}}, } if editFunc != nil { editFunc(u) @@ -503,6 +504,7 @@ func TestEndUserAuthentication(t *testing.T) { Name: testUserSearchResultUsernameAttributeValue, UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)), Groups: []string{"a", "b", "c"}, + Extra: map[string][]string{"userDN": {testUserSearchResultDNValue}}, }, }, }, @@ -1212,6 +1214,151 @@ func TestEndUserAuthentication(t *testing.T) { } } +func TestUpstreamRefresh(t *testing.T) { + expectedUserSearch := &ldap.SearchRequest{ + BaseDN: testUserSearchResultDNValue, + Scope: ldap.ScopeBaseObject, + DerefAliases: ldap.NeverDerefAliases, + SizeLimit: 2, + TimeLimit: 90, + TypesOnly: false, + Filter: "(objectClass=*)", + Attributes: []string{}, + Controls: nil, // don't need paging because we set the SizeLimit so small + } + + happyPathUserSearchResult := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{}, + }, + }, + } + + providerConfig := &ProviderConfig{ + Name: "some-provider-name", + Host: testHost, + CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test + ConnectionProtocol: TLS, + BindUsername: testBindUsername, + BindPassword: testBindPassword, + UserSearch: UserSearchConfig{ + Base: testUserSearchBase, + }, + } + + tests := []struct { + name string + providerConfig *ProviderConfig + setupMocks func(conn *mockldapconn.MockConn) + dialError error + wantErr string + }{ + { + name: "happy path where searching the dn returns a single entry", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + }, + { + name: "error where dial fails", + providerConfig: providerConfig, + dialError: errors.New("some dial error"), + wantErr: "error dialing host \"ldap.example.com:8443\": some dial error", + }, + { + name: "error binding", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "error binding as \"cn=some-bind-username,dc=pinniped,dc=dev\" before user search: some bind error", + }, + { + name: "search result returns no entries", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{}, + }, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "searching for user \"some-upstream-user-dn\" resulted in 0 search results, but expected 1 result", + }, + { + name: "error searching", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(nil, errors.New("some search error")) + conn.EXPECT().Close().Times(1) + }, + wantErr: "error searching for user \"some-upstream-user-dn\": some search error", + }, + { + name: "search result returns more than one entry", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{}, + }, + { + DN: "doesn't-matter", + Attributes: []*ldap.EntryAttribute{}, + }, + }, + }, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "searching for user \"some-upstream-user-dn\" resulted in 2 search results, but expected 1 result", + }, + } + + for _, test := range tests { + tt := test + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + conn := mockldapconn.NewMockConn(ctrl) + if tt.setupMocks != nil { + tt.setupMocks(conn) + } + + dialWasAttempted := false + providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { + dialWasAttempted = true + require.Equal(t, providerConfig.Host, addr.Endpoint()) + if tt.dialError != nil { + return nil, tt.dialError + } + + return conn, nil + }) + + provider := New(*providerConfig) + err := provider.PerformRefresh(context.Background(), testUserSearchResultDNValue) + if tt.wantErr != "" { + require.NotNil(t, err) + require.Equal(t, tt.wantErr, err.Error()) + } else { + require.NoError(t, err) + } + require.Equal(t, true, dialWasAttempted) + }) + } +} + func TestTestConnection(t *testing.T) { providerConfig := func(editFunc func(p *ProviderConfig)) *ProviderConfig { config := &ProviderConfig{ diff --git a/test/integration/ldap_client_test.go b/test/integration/ldap_client_test.go index 9c21698c..8cafa810 100644 --- a/test/integration/ldap_client_test.go +++ b/test/integration/ldap_client_test.go @@ -83,7 +83,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(nil)), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -95,7 +95,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.ConnectionProtocol = upstreamldap.StartTLS })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -104,7 +104,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "dc=pinniped,dc=dev" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -113,7 +113,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "(cn={})" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -125,7 +125,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.UserSearch.Filter = "cn={}" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "cn=pinny,ou=users,dc=pinniped,dc=dev", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "cn=pinny,ou=users,dc=pinniped,dc=dev", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -136,7 +136,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.UserSearch.Filter = "(|(cn={})(mail={}))" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -147,7 +147,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.UserSearch.Filter = "(|(cn={})(mail={}))" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -156,7 +156,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "dn" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("cn=pinny,ou=users,dc=pinniped,dc=dev"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("cn=pinny,ou=users,dc=pinniped,dc=dev"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -165,7 +165,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "sn" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("Seal"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("Seal"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -174,7 +174,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { password: pinnyPassword, provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "sn" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "Seal", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, // note that the final answer has case preserved from the entry + User: &user.DefaultInfo{Name: "Seal", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, // note that the final answer has case preserved from the entry }, }, { @@ -187,7 +187,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.UserSearch.UIDAttribute = "givenName" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "Pinny the 🦭", UID: b64("Pinny the 🦭"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "Pinny the 🦭", UID: b64("Pinny the 🦭"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -199,7 +199,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.UserSearch.UsernameAttribute = "cn" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -220,7 +220,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.Base = "" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -231,7 +231,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -245,7 +245,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{ "cn=ball-game-players,ou=beach-groups,ou=groups,dc=pinniped,dc=dev", "cn=seals,ou=groups,dc=pinniped,dc=dev", - }}, + }, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -259,7 +259,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{ "cn=ball-game-players,ou=beach-groups,ou=groups,dc=pinniped,dc=dev", "cn=seals,ou=groups,dc=pinniped,dc=dev", - }}, + }, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -270,7 +270,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -281,7 +281,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.Filter = "(&(&(objectClass=groupOfNames)(member={}))(cn=seals))" })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -292,7 +292,7 @@ func TestLDAPSearch_Parallel(t *testing.T) { p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema })), wantAuthResponse: &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, }, { @@ -671,7 +671,7 @@ func TestSimultaneousLDAPRequestsOnSingleProvider(t *testing.T) { assert.NoError(t, result.err) assert.True(t, result.authenticated, "expected the user to be authenticated, but they were not") assert.Equal(t, &authenticator.Response{ - User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}}, + User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"ball-game-players", "seals"}, Extra: map[string][]string{"userDN": {"cn=pinny,ou=users,dc=pinniped,dc=dev"}}}, }, result.response) } } diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 770d67b7..33881325 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -214,7 +214,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.LDAP.UserDN) + customSessionData.LDAP.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamLDAP.Host+ @@ -281,7 +285,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.LDAP.UserDN) + customSessionData.LDAP.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamLDAP.StartTLSOnlyHost+ @@ -348,9 +356,13 @@ func TestSupervisorLogin(t *testing.T) { true, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type - wantErrorDescription: "The resource owner or authorization server denied the request. Username/password not accepted by LDAP provider.", - wantErrorType: "access_denied", + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.LDAP.UserDN) + customSessionData.LDAP.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, + wantErrorDescription: "The resource owner or authorization server denied the request. Username/password not accepted by LDAP provider.", + wantErrorType: "access_denied", }, { name: "ldap login still works after updating bind secret", @@ -426,7 +438,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.LDAP.UserDN) + customSessionData.LDAP.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamLDAP.Host+ @@ -525,7 +541,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.LDAP.UserDN) + customSessionData.LDAP.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamLDAP.Host+ @@ -580,7 +600,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeActiveDirectory, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.ActiveDirectory.UserDN) + customSessionData.ActiveDirectory.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamActiveDirectory.Host+ @@ -648,7 +672,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeActiveDirectory, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.ActiveDirectory.UserDN) + customSessionData.ActiveDirectory.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamActiveDirectory.Host+ @@ -721,7 +749,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeActiveDirectory, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.ActiveDirectory.UserDN) + customSessionData.ActiveDirectory.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamActiveDirectory.Host+ @@ -809,7 +841,11 @@ func TestSupervisorLogin(t *testing.T) { false, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: func(t *testing.T, customSessionData *psession.CustomSessionData) { + require.Equal(t, psession.ProviderTypeActiveDirectory, customSessionData.ProviderType) + require.NotEmpty(t, customSessionData.ActiveDirectory.UserDN) + customSessionData.ActiveDirectory.UserDN = "cn=not-a-user,dc=pinniped,dc=dev" + }, // the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta( "ldaps://"+env.SupervisorUpstreamActiveDirectory.Host+ @@ -864,7 +900,7 @@ func TestSupervisorLogin(t *testing.T) { true, ) }, - breakRefreshSessionData: nil, // upstream refresh not yet implemented for this IDP type + breakRefreshSessionData: nil, wantErrorDescription: "The resource owner or authorization server denied the request. Username/password not accepted by LDAP provider.", wantErrorType: "access_denied", },