Add a couple tests, address pr comments

Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
Margo Crawford 2022-06-22 14:19:55 -07:00
parent f2005b4c7f
commit dac0395680
3 changed files with 51 additions and 11 deletions

View File

@ -303,9 +303,12 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
return err return err
} }
subject := session.Fosite.Claims.Subject subject := session.Fosite.Claims.Subject
oldGroups, err := getDownstreamGroupsFromPinnipedSession(session) var oldGroups []string
if err != nil { if slices.Contains(grantedScopes, oidc.DownstreamGroupsScope) {
return err oldGroups, err = getDownstreamGroupsFromPinnipedSession(session)
if err != nil {
return err
}
} }
s := session.Custom s := session.Custom
@ -410,7 +413,7 @@ func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) (
} }
downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim] downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim]
if downstreamGroupsInterface == nil { if downstreamGroupsInterface == nil {
return nil, nil return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{}) downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{})
if !ok { if !ok {

View File

@ -16,19 +16,17 @@ import (
"strings" "strings"
"time" "time"
"go.pinniped.dev/internal/oidc"
"k8s.io/utils/strings/slices"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/utils/strings/slices"
"k8s.io/utils/trace" "k8s.io/utils/trace"
"go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"

View File

@ -174,6 +174,7 @@ func TestEndUserAuthentication(t *testing.T) {
name string name string
username string username string
password string password string
grantedScopes []string
providerConfig *ProviderConfig providerConfig *ProviderConfig
searchMocks func(conn *mockldapconn.MockConn) searchMocks func(conn *mockldapconn.MockConn)
bindEndUserMocks func(conn *mockldapconn.MockConn) bindEndUserMocks func(conn *mockldapconn.MockConn)
@ -286,6 +287,25 @@ func TestEndUserAuthentication(t *testing.T) {
info.Groups = []string{} info.Groups = []string{}
}), }),
}, },
{
name: "when groups scope isn't granted, don't do group search",
username: testUpstreamUsername,
password: testUpstreamPassword,
grantedScopes: []string{},
providerConfig: providerConfig(nil),
searchMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch(nil)).Return(exampleUserSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
bindEndUserMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Times(1)
},
wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) {
info := r.User.(*user.DefaultInfo)
info.Groups = nil
}),
},
{ {
name: "when the UsernameAttribute is dn and there is a user search filter provided", name: "when the UsernameAttribute is dn and there is a user search filter provided",
username: testUpstreamUsername, username: testUpstreamUsername,
@ -1167,7 +1187,11 @@ func TestEndUserAuthentication(t *testing.T) {
ldapProvider := New(*tt.providerConfig) ldapProvider := New(*tt.providerConfig)
authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, []string{"groups"}) if tt.grantedScopes == nil {
tt.grantedScopes = []string{"groups"}
}
authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, tt.grantedScopes)
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
switch { switch {
case tt.wantError != "": case tt.wantError != "":
@ -1199,7 +1223,7 @@ func TestEndUserAuthentication(t *testing.T) {
} }
// Skip tt.bindEndUserMocks since DryRunAuthenticateUser() never binds as the end user. // Skip tt.bindEndUserMocks since DryRunAuthenticateUser() never binds as the end user.
authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, []string{"groups"}) authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, tt.grantedScopes)
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
switch { switch {
case tt.wantError != "": case tt.wantError != "":
@ -1331,6 +1355,7 @@ func TestUpstreamRefresh(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
providerConfig *ProviderConfig providerConfig *ProviderConfig
grantedScopes []string
setupMocks func(conn *mockldapconn.MockConn) setupMocks func(conn *mockldapconn.MockConn)
refreshUserDN string refreshUserDN string
dialError error dialError error
@ -1465,6 +1490,17 @@ func TestUpstreamRefresh(t *testing.T) {
}, },
wantGroups: nil, // do not update groups wantGroups: nil, // do not update groups
}, },
{
name: "happy path where group search is configured but groups scope isn't included",
providerConfig: providerConfig(nil),
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
grantedScopes: []string{},
wantGroups: nil,
},
{ {
name: "error where dial fails", name: "error where dial fails",
providerConfig: providerConfig(nil), providerConfig: providerConfig(nil),
@ -1769,6 +1805,9 @@ func TestUpstreamRefresh(t *testing.T) {
tt.refreshUserDN = testUserSearchResultDNValue // default for all tests tt.refreshUserDN = testUserSearchResultDNValue // default for all tests
} }
if tt.grantedScopes == nil {
tt.grantedScopes = []string{"groups"}
}
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000")) initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
ldapProvider := New(*tt.providerConfig) ldapProvider := New(*tt.providerConfig)
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
@ -1777,7 +1816,7 @@ func TestUpstreamRefresh(t *testing.T) {
Subject: subject, Subject: subject,
DN: tt.refreshUserDN, DN: tt.refreshUserDN,
AdditionalAttributes: map[string]string{pwdLastSetAttribute: initialPwdLastSetEncoded}, AdditionalAttributes: map[string]string{pwdLastSetAttribute: initialPwdLastSetEncoded},
GrantedScopes: []string{"groups"}, GrantedScopes: tt.grantedScopes,
}) })
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(t, err) require.Error(t, err)