diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index d22f9ad7..cc637b4d 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -487,10 +487,7 @@ func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken func downstreamSubjectFromUpstreamLDAP(ldapUpstream provider.UpstreamLDAPIdentityProviderI, authenticateResponse *authenticator.Response) string { ldapURL := *ldapUpstream.GetURL() - q := ldapURL.Query() - q.Set(oidc.IDTokenSubjectClaim, authenticateResponse.User.GetUID()) - ldapURL.RawQuery = q.Encode() - return ldapURL.String() + return downstreamsession.DownstreamLDAPSubject(authenticateResponse.User.GetUID(), ldapURL) } func userDNFromAuthenticatedResponse(authenticatedResponse *authenticator.Response) string { diff --git a/internal/oidc/downstreamsession/downstream_session.go b/internal/oidc/downstreamsession/downstream_session.go index 0fee5a78..8618ab57 100644 --- a/internal/oidc/downstreamsession/downstream_session.go +++ b/internal/oidc/downstreamsession/downstream_session.go @@ -169,6 +169,13 @@ func extractStringClaimValue(claimName string, upstreamIDPName string, idTokenCl return valueAsString, nil } +func DownstreamLDAPSubject(uid string, ldapURL url.URL) string { + q := ldapURL.Query() + q.Set(oidc.IDTokenSubjectClaim, uid) + ldapURL.RawQuery = q.Encode() + return ldapURL.String() +} + func downstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSubject string) string { return fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, url.QueryEscape(upstreamSubject)) } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index c5f2b7b1..d7aa02c0 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -90,7 +90,7 @@ type UpstreamLDAPIdentityProviderI interface { authenticators.UserAuthenticator // PerformRefresh performs a refresh against the upstream LDAP identity provider - PerformRefresh(ctx context.Context, userDN string) error + PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error } type DynamicUpstreamIDPProvider interface { diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index dcec411b..fee06a63 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -6,6 +6,7 @@ package token import ( "context" + "fmt" "net/http" "github.com/ory/fosite" @@ -75,6 +76,18 @@ func NewHandler( func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { session := accessRequest.GetSession().(*psession.PinnipedSession) + fositeSession := session.Fosite + if fositeSession == nil { + return fmt.Errorf("fosite session not found") + } + claims := fositeSession.Claims + if claims == nil { + return fmt.Errorf("fosite session claims not found") + } + extra := claims.Extra + downstreamUsername := extra["username"].(string) + downstreamSubject := claims.Subject + customSessionData := session.Custom if customSessionData == nil { return errorsx.WithStack(errMissingUpstreamSessionInternalError) @@ -89,9 +102,9 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, case psession.ProviderTypeOIDC: return upstreamOIDCRefresh(ctx, customSessionData, providerCache) case psession.ProviderTypeLDAP: - return upstreamLDAPRefresh(ctx, customSessionData, providerCache) + return upstreamLDAPRefresh(ctx, customSessionData, providerCache, downstreamUsername, downstreamSubject) case psession.ProviderTypeActiveDirectory: - return upstreamLDAPRefresh(ctx, customSessionData, providerCache) + return upstreamLDAPRefresh(ctx, customSessionData, providerCache, downstreamUsername, downstreamSubject) default: return errorsx.WithStack(errMissingUpstreamSessionInternalError) } @@ -164,19 +177,19 @@ func findOIDCProviderByNameAndValidateUID( WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) } -func upstreamLDAPRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { - plog.Warning("refreshing upstream") +func upstreamLDAPRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister, username string, subject string) error { // if you have neither a valid ldap session config nor a valid active directory session config if (s.LDAP == nil || s.LDAP.UserDN == "") && (s.ActiveDirectory == nil || s.ActiveDirectory.UserDN == "") { return errorsx.WithStack(errMissingUpstreamSessionInternalError) } - plog.Warning("going to find provider", "provider", s.ProviderName) // get ldap/ad provider out of cache - p, dn, _ := findLDAPProviderByNameAndValidateUID(s, providerCache) - // TODO error checking + p, dn, err := findLDAPProviderByNameAndValidateUID(s, providerCache) + if err != nil { + return err + } // run PerformRefresh - err := p.PerformRefresh(ctx, dn) + err = p.PerformRefresh(ctx, dn, username, subject) if err != nil { return errorsx.WithStack(errUpstreamRefreshError.WithHintf( "Upstream refresh failed using provider %q of type %q.", diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index c819847e..50825b80 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -933,8 +933,10 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamRefresh{ performedByUpstreamName: ldapUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, - DN: ldapUpstreamDN, + Ctx: nil, + DN: ldapUpstreamDN, + ExpectedSubject: goodSubject, + ExpectedUsername: goodUsername, }, } } @@ -943,8 +945,10 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamRefresh{ performedByUpstreamName: activeDirectoryUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, - DN: activeDirectoryUpstreamDN, + Ctx: nil, + DN: activeDirectoryUpstreamDN, + ExpectedSubject: goodSubject, + ExpectedUsername: goodUsername, }, } } @@ -1796,7 +1800,8 @@ func TestRefreshGrant(t *testing.T) { }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusUnauthorized, + wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(), + wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { "error": "error", @@ -1837,7 +1842,8 @@ func TestRefreshGrant(t *testing.T) { }, refreshRequest: refreshRequestInputs{ want: tokenEndpointResponseExpectedValues{ - wantStatus: http.StatusUnauthorized, + wantUpstreamRefreshCall: happyActiveDirectoryUpstreamRefreshCall(), + wantStatus: http.StatusUnauthorized, wantErrorResponseBody: here.Doc(` { "error": "error", @@ -1847,6 +1853,78 @@ func TestRefreshGrant(t *testing.T) { }, }, }, + { + name: "upstream ldap idp not found", + idps: oidctestutil.NewUpstreamIDPListerBuilder(), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: ldapUpstreamResourceUID, + ProviderName: ldapUpstreamName, + ProviderType: ldapUpstreamType, + LDAP: &psession.LDAPSessionData{ + UserDN: ldapUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'some-ldap-idp' of type 'ldap' from upstream session data was not found." + } + `), + }, + }, + }, + { + name: "upstream active directory idp not found", + idps: oidctestutil.NewUpstreamIDPListerBuilder(), + authcodeExchange: authcodeExchangeInputs{ + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + customSessionData: &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess( + &psession.CustomSessionData{ + ProviderUID: activeDirectoryUpstreamResourceUID, + ProviderName: activeDirectoryUpstreamName, + ProviderType: activeDirectoryUpstreamType, + ActiveDirectory: &psession.ActiveDirectorySessionData{ + UserDN: activeDirectoryUpstreamDN, + }, + }, + ), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Provider 'some-ad-idp' of type 'activedirectory' from upstream session data was not found." + } + `), + }, + }, + }, } for _, test := range tests { test := test diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 84d160f1..23b193f0 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -61,9 +61,11 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct { // PerformRefreshArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). type PerformRefreshArgs struct { - Ctx context.Context - RefreshToken string - DN string + Ctx context.Context + RefreshToken string + DN string + ExpectedUsername string + ExpectedSubject string } // ValidateTokenArgs is used to spy on calls to @@ -102,14 +104,16 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL { return u.URL } -func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, userDN string) error { +func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } u.performRefreshCallCount++ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ - Ctx: ctx, - DN: userDN, + Ctx: ctx, + DN: userDN, + ExpectedUsername: expectedUsername, + ExpectedSubject: expectedSubject, }) if u.PerformRefreshErr != nil { return u.PerformRefreshErr diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index b0b05e07..9f24f762 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -27,6 +27,7 @@ import ( "go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/endpointaddr" + "go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/plog" ) @@ -169,7 +170,7 @@ func (p *Provider) GetConfig() ProviderConfig { return p.c } -func (p *Provider) PerformRefresh(ctx context.Context, userDN string) error { +func (p *Provider) PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error { t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches search := p.refreshUserSearchRequest(userDN) @@ -201,6 +202,30 @@ func (p *Provider) PerformRefresh(ctx context.Context, userDN string) error { ) } + userEntry := searchResult.Entries[0] + if len(userEntry.DN) == 0 { + return fmt.Errorf(`searching for user with original DN "%s" resulted in search result without DN`, userDN) + } + + newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN) + if err != nil { + return err // TODO test having no values or more than one maybe + } + if newUsername != expectedUsername { + return fmt.Errorf(`searching for user "%s" returned a different username than the previous value. expected: "%s", actual: "%s"`, + userDN, expectedUsername, newUsername, + ) + } + + newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN) + if err != nil { + return err // TODO test + } + newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL()) + if newSubject != expectedSubject { + return fmt.Errorf(`searching for user "%s" produced a different subject than the previous value. expected: "%s", actual: "%s"`, userDN, expectedSubject, newSubject) + } + // do nothing. if we got exactly one search result back then that means the user // still exists. return nil @@ -615,9 +640,9 @@ func (p *Provider) refreshUserSearchRequest(dn string) *ldap.SearchRequest { SizeLimit: 2, TimeLimit: 90, TypesOnly: false, - Filter: "(objectClass=*)", // we already have the dn, so the filter doesn't matter - Attributes: []string{}, // TODO this will need to include some other AD attributes - Controls: nil, // this could be used to enable paging, but we're already limiting the result max size + Filter: "(objectClass=*)", // we already have the dn, so the filter doesn't matter + Attributes: p.userSearchRequestedAttributes(), // TODO this will need to include some other AD attributes + Controls: nil, // this could be used to enable paging, but we're already limiting the result max size } } diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index b7f953be..259c6e94 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -1223,15 +1223,25 @@ func TestUpstreamRefresh(t *testing.T) { TimeLimit: 90, TypesOnly: false, Filter: "(objectClass=*)", - Attributes: []string{}, + Attributes: []string{testUserSearchUsernameAttribute, testUserSearchUIDAttribute}, Controls: nil, // don't need paging because we set the SizeLimit so small } happyPathUserSearchResult := &ldap.SearchResult{ Entries: []*ldap.Entry{ { - DN: testUserSearchResultDNValue, - Attributes: []*ldap.EntryAttribute{}, + DN: testUserSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{testUserSearchResultUsernameAttributeValue}, + }, + { + Name: testUserSearchUIDAttribute, + Values: []string{testUserSearchResultUIDAttributeValue}, + ByteValues: [][]byte{[]byte(testUserSearchResultUIDAttributeValue)}, + }, + }, }, }, } @@ -1244,7 +1254,9 @@ func TestUpstreamRefresh(t *testing.T) { BindUsername: testBindUsername, BindPassword: testBindPassword, UserSearch: UserSearchConfig{ - Base: testUserSearchBase, + Base: testUserSearchBase, + UIDAttribute: testUserSearchUIDAttribute, + UsernameAttribute: testUserSearchUsernameAttribute, }, } @@ -1322,6 +1334,80 @@ func TestUpstreamRefresh(t *testing.T) { }, wantErr: "searching for user \"some-upstream-user-dn\" resulted in 2 search results, but expected 1 result", }, + { + name: "search result has wrong uid", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{testUserSearchResultUsernameAttributeValue}, + }, + { + Name: testUserSearchUIDAttribute, + Values: []string{"wrong-uid"}, + ByteValues: [][]byte{[]byte("wrong-uid")}, + }, + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "searching for user \"some-upstream-user-dn\" produced a different subject than the previous value. expected: \"ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU\", actual: \"ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=d3JvbmctdWlk\"", + }, + { + name: "search result has wrong username", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserSearchResultDNValue, + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{"wrong-username"}, + }, + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "searching for user \"some-upstream-user-dn\" returned a different username than the previous value. expected: \"some-upstream-username-value\", actual: \"wrong-username\"", + }, + { + name: "search result has no dn", + providerConfig: providerConfig, + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{testUserSearchResultUsernameAttributeValue}, + }, + { + Name: testUserSearchUIDAttribute, + Values: []string{testUserSearchResultUIDAttributeValue}, + }, + }, + }, + }, + }, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + wantErr: "searching for user with original DN \"some-upstream-user-dn\" resulted in search result without DN", + }, } for _, test := range tests { @@ -1347,9 +1433,10 @@ func TestUpstreamRefresh(t *testing.T) { }) provider := New(*providerConfig) - err := provider.PerformRefresh(context.Background(), testUserSearchResultDNValue) + subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" + err := provider.PerformRefresh(context.Background(), testUserSearchResultDNValue, testUserSearchResultUsernameAttributeValue, subject) if tt.wantErr != "" { - require.NotNil(t, err) + require.Error(t, err) require.Equal(t, tt.wantErr, err.Error()) } else { require.NoError(t, err)