diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 76727a12..c0044fc5 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -303,9 +303,12 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit return err } subject := session.Fosite.Claims.Subject - oldGroups, err := getDownstreamGroupsFromPinnipedSession(session) - if err != nil { - return err + var oldGroups []string + if slices.Contains(grantedScopes, oidc.DownstreamGroupsScope) { + oldGroups, err = getDownstreamGroupsFromPinnipedSession(session) + if err != nil { + return err + } } s := session.Custom @@ -410,7 +413,7 @@ func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) ( } downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim] if downstreamGroupsInterface == nil { - return nil, nil + return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError()) } downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{}) if !ok { diff --git a/internal/upstreamldap/upstreamldap.go b/internal/upstreamldap/upstreamldap.go index 7317d939..cfbd437f 100644 --- a/internal/upstreamldap/upstreamldap.go +++ b/internal/upstreamldap/upstreamldap.go @@ -16,19 +16,17 @@ import ( "strings" "time" - "go.pinniped.dev/internal/oidc" - - "k8s.io/utils/strings/slices" - "github.com/go-ldap/ldap/v3" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/authentication/user" + "k8s.io/utils/strings/slices" "k8s.io/utils/trace" "go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/endpointaddr" + "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/plog" diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index 9a9ca782..892c9f25 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -174,6 +174,7 @@ func TestEndUserAuthentication(t *testing.T) { name string username string password string + grantedScopes []string providerConfig *ProviderConfig searchMocks func(conn *mockldapconn.MockConn) bindEndUserMocks func(conn *mockldapconn.MockConn) @@ -286,6 +287,25 @@ func TestEndUserAuthentication(t *testing.T) { info.Groups = []string{} }), }, + { + name: "when groups scope isn't granted, don't do group search", + username: testUpstreamUsername, + password: testUpstreamPassword, + grantedScopes: []string{}, + providerConfig: providerConfig(nil), + searchMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(exampleUserSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + bindEndUserMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Times(1) + }, + wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { + info := r.User.(*user.DefaultInfo) + info.Groups = nil + }), + }, { name: "when the UsernameAttribute is dn and there is a user search filter provided", username: testUpstreamUsername, @@ -1167,7 +1187,11 @@ func TestEndUserAuthentication(t *testing.T) { ldapProvider := New(*tt.providerConfig) - authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, []string{"groups"}) + if tt.grantedScopes == nil { + tt.grantedScopes = []string{"groups"} + } + + authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, tt.grantedScopes) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) switch { case tt.wantError != "": @@ -1199,7 +1223,7 @@ func TestEndUserAuthentication(t *testing.T) { } // Skip tt.bindEndUserMocks since DryRunAuthenticateUser() never binds as the end user. - authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, []string{"groups"}) + authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, tt.grantedScopes) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) switch { case tt.wantError != "": @@ -1331,6 +1355,7 @@ func TestUpstreamRefresh(t *testing.T) { tests := []struct { name string providerConfig *ProviderConfig + grantedScopes []string setupMocks func(conn *mockldapconn.MockConn) refreshUserDN string dialError error @@ -1465,6 +1490,17 @@ func TestUpstreamRefresh(t *testing.T) { }, wantGroups: nil, // do not update groups }, + { + name: "happy path where group search is configured but groups scope isn't included", + providerConfig: providerConfig(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + grantedScopes: []string{}, + wantGroups: nil, + }, { name: "error where dial fails", providerConfig: providerConfig(nil), @@ -1769,6 +1805,9 @@ func TestUpstreamRefresh(t *testing.T) { tt.refreshUserDN = testUserSearchResultDNValue // default for all tests } + if tt.grantedScopes == nil { + tt.grantedScopes = []string{"groups"} + } initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000")) ldapProvider := New(*tt.providerConfig) subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" @@ -1777,7 +1816,7 @@ func TestUpstreamRefresh(t *testing.T) { Subject: subject, DN: tt.refreshUserDN, AdditionalAttributes: map[string]string{pwdLastSetAttribute: initialPwdLastSetEncoded}, - GrantedScopes: []string{"groups"}, + GrantedScopes: tt.grantedScopes, }) if tt.wantErr != "" { require.Error(t, err)