Merge branch 'main' into proposal_process

This commit is contained in:
Ryan Richard 2022-02-17 12:48:58 -08:00 committed by GitHub
commit dec89b5378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 549 additions and 96 deletions

View File

@ -108,7 +108,7 @@ type UpstreamLDAPIdentityProviderI interface {
authenticators.UserAuthenticator authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider // PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) error PerformRefresh(ctx context.Context, storedRefreshAttributes StoredRefreshAttributes) (groups []string, err error)
} }
type StoredRefreshAttributes struct { type StoredRefreshAttributes struct {

View File

@ -301,7 +301,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
} }
// run PerformRefresh // run PerformRefresh
err = p.PerformRefresh(ctx, provider.StoredRefreshAttributes{ groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
Username: username, Username: username,
Subject: subject, Subject: subject,
DN: dn, DN: dn,
@ -312,6 +312,8 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
"Upstream refresh failed.").WithWrap(err). "Upstream refresh failed.").WithWrap(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
// Replace the old value with the new value.
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
return nil return nil
} }

View File

@ -1339,6 +1339,60 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
}, },
{
name: "happy path refresh grant when the upstream refresh returns new group memberships from LDAP, it updates groups",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
Name: ldapUpstreamName,
ResourceUID: ldapUpstreamResourceUID,
URL: ldapUpstreamURL,
PerformRefreshGroups: []string{"new-group1", "new-group2", "new-group3"},
}),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: happyLDAPCustomSessionData,
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
happyLDAPCustomSessionData,
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"openid", "offline_access"},
wantGrantedScopes: []string{"openid", "offline_access"},
wantGroups: []string{"new-group1", "new-group2", "new-group3"},
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
wantCustomSessionDataStored: happyLDAPCustomSessionData,
},
},
},
{
name: "happy path refresh grant when the upstream refresh returns empty list of group memberships from LDAP, it updates groups to an empty list",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
Name: ldapUpstreamName,
ResourceUID: ldapUpstreamResourceUID,
URL: ldapUpstreamURL,
PerformRefreshGroups: []string{},
}),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: happyLDAPCustomSessionData,
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
happyLDAPCustomSessionData,
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusOK,
wantSuccessBodyFields: []string{"refresh_token", "access_token", "id_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"openid", "offline_access"},
wantGrantedScopes: []string{"openid", "offline_access"},
wantGroups: []string{},
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
wantCustomSessionDataStored: happyLDAPCustomSessionData,
},
},
},
{ {
name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value", name: "error from refresh grant when the upstream refresh does not return new group memberships from the merged ID token and userinfo results by returning group claim with illegal nil value",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
@ -1970,6 +2024,7 @@ func TestRefreshGrant(t *testing.T) {
Name: ldapUpstreamName, Name: ldapUpstreamName,
ResourceUID: ldapUpstreamResourceUID, ResourceUID: ldapUpstreamResourceUID,
URL: ldapUpstreamURL, URL: ldapUpstreamURL,
PerformRefreshGroups: goodGroups,
}), }),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
@ -1990,6 +2045,7 @@ func TestRefreshGrant(t *testing.T) {
Name: activeDirectoryUpstreamName, Name: activeDirectoryUpstreamName,
ResourceUID: activeDirectoryUpstreamResourceUID, ResourceUID: activeDirectoryUpstreamResourceUID,
URL: ldapUpstreamURL, URL: ldapUpstreamURL,
PerformRefreshGroups: goodGroups,
}), }),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },

View File

@ -101,6 +101,7 @@ type TestUpstreamLDAPIdentityProvider struct {
performRefreshCallCount int performRefreshCallCount int
performRefreshArgs []*PerformRefreshArgs performRefreshArgs []*PerformRefreshArgs
PerformRefreshErr error PerformRefreshErr error
PerformRefreshGroups []string
} }
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{} var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
@ -121,7 +122,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL return u.URL
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
if u.performRefreshArgs == nil { if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0) u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
} }
@ -133,9 +134,9 @@ func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, s
ExpectedSubject: storedRefreshAttributes.Subject, ExpectedSubject: storedRefreshAttributes.Subject,
}) })
if u.PerformRefreshErr != nil { if u.PerformRefreshErr != nil {
return u.PerformRefreshErr return nil, u.PerformRefreshErr
} }
return nil return u.PerformRefreshGroups, nil
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int { func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int {

View File

@ -13,12 +13,12 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"sort"
"strings" "strings"
"time" "time"
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/utils/trace" "k8s.io/utils/trace"
@ -170,61 +170,11 @@ func (p *Provider) GetConfig() ProviderConfig {
return p.c return p.c
} }
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) ([]string, error) {
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
userDN := storedRefreshAttributes.DN userDN := storedRefreshAttributes.DN
searchResult, err := p.performRefresh(ctx, userDN)
if err != nil {
p.traceRefreshFailure(t, err)
return err
}
// if any more or less than one entry, error.
// we don't need to worry about logging this because we know it's a dn.
if len(searchResult.Entries) != 1 {
return fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
userDN, len(searchResult.Entries),
)
}
userEntry := searchResult.Entries[0]
if len(userEntry.DN) == 0 {
return fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
}
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
if err != nil {
return err
}
if newUsername != storedRefreshAttributes.Username {
return fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
userDN, storedRefreshAttributes.Username, newUsername,
)
}
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
if err != nil {
return err
}
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
if newSubject != storedRefreshAttributes.Subject {
return fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
}
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
err = validateFunc(userEntry, storedRefreshAttributes)
if err != nil {
return fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
}
}
// we checked that the user still exists and their information is the same, so just return.
return nil
}
func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.SearchResult, error) {
search := p.refreshUserSearchRequest(userDN)
conn, err := p.dial(ctx) conn, err := p.dial(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err) return nil, fmt.Errorf(`error dialing host %q: %w`, p.c.Host, err)
@ -236,6 +186,60 @@ func (p *Provider) performRefresh(ctx context.Context, userDN string) (*ldap.Sea
return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err) return nil, fmt.Errorf(`error binding as %q before user search: %w`, p.c.BindUsername, err)
} }
searchResult, err := p.performUserRefreshSearch(conn, userDN)
if err != nil {
p.traceRefreshFailure(t, err)
return nil, err
}
// if any more or less than one entry, error.
// we don't need to worry about logging this because we know it's a dn.
if len(searchResult.Entries) != 1 {
return nil, fmt.Errorf(`searching for user %q resulted in %d search results, but expected 1 result`,
userDN, len(searchResult.Entries),
)
}
userEntry := searchResult.Entries[0]
if len(userEntry.DN) == 0 {
return nil, fmt.Errorf(`searching for user with original DN %q resulted in search result without DN`, userDN)
}
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
if err != nil {
return nil, err
}
if newUsername != storedRefreshAttributes.Username {
return nil, fmt.Errorf(`searching for user %q returned a different username than the previous value. expected: %q, actual: %q`,
userDN, storedRefreshAttributes.Username, newUsername,
)
}
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
if err != nil {
return nil, err
}
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
if newSubject != storedRefreshAttributes.Subject {
return nil, fmt.Errorf(`searching for user %q produced a different subject than the previous value. expected: %q, actual: %q`, userDN, storedRefreshAttributes.Subject, newSubject)
}
for attribute, validateFunc := range p.c.RefreshAttributeChecks {
err = validateFunc(userEntry, storedRefreshAttributes)
if err != nil {
return nil, fmt.Errorf(`validation for attribute %q failed during upstream refresh: %w`, attribute, err)
}
}
mappedGroupNames, err := p.searchGroupsForUserDN(conn, userDN)
if err != nil {
return nil, err
}
return mappedGroupNames, nil
}
func (p *Provider) performUserRefreshSearch(conn Conn, userDN string) (*ldap.SearchResult, error) {
search := p.refreshUserSearchRequest(userDN)
searchResult, err := conn.Search(search) searchResult, err := conn.Search(search)
if err != nil { if err != nil {
@ -445,6 +449,11 @@ func (p *Provider) authenticateUserImpl(ctx context.Context, username string, bi
} }
func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) { func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, error) {
// If we do not have group search configured, skip this search.
if len(p.c.GroupSearch.Base) == 0 {
return []string{}, nil
}
searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize) searchResult, err := conn.SearchWithPaging(p.groupSearchRequest(userDN), groupSearchPageSize)
if err != nil { if err != nil {
return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err) return nil, fmt.Errorf(`error searching for group memberships for user with DN %q: %w`, userDN, err)
@ -455,7 +464,7 @@ func (p *Provider) searchGroupsForUserDN(conn Conn, userDN string) ([]string, er
groupAttributeName = distinguishedNameAttributeName groupAttributeName = distinguishedNameAttributeName
} }
var groups []string groups := []string{}
entries: entries:
for _, groupEntry := range searchResult.Entries { for _, groupEntry := range searchResult.Entries {
if len(groupEntry.DN) == 0 { if len(groupEntry.DN) == 0 {
@ -476,8 +485,9 @@ entries:
} }
groups = append(groups, mappedGroupName) groups = append(groups, mappedGroupName)
} }
// de-duplicate the list of groups by turning it into a set,
return groups, nil // then turn it back into a sorted list.
return sets.NewString(groups...).List(), nil
} }
func (p *Provider) validateConfig() error { func (p *Provider) validateConfig() error {
@ -567,14 +577,10 @@ func (p *Provider) searchAndBindUser(conn Conn, username string, bindFunc func(c
return nil, err return nil, err
} }
var mappedGroupNames []string mappedGroupNames, err := p.searchGroupsForUserDN(conn, userEntry.DN)
if len(p.c.GroupSearch.Base) > 0 {
mappedGroupNames, err = p.searchGroupsForUserDN(conn, userEntry.DN)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
sort.Strings(mappedGroupNames)
mappedRefreshAttributes := make(map[string]string) mappedRefreshAttributes := make(map[string]string)
for k := range p.c.RefreshAttributeChecks { for k := range p.c.RefreshAttributeChecks {

View File

@ -237,6 +237,33 @@ func TestEndUserAuthentication(t *testing.T) {
}, },
wantAuthResponse: expectedAuthResponse(nil), wantAuthResponse: expectedAuthResponse(nil),
}, },
{
name: "when the group search has an override func",
username: testUpstreamUsername,
password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) {
p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) {
return "something-else-" + entry.DN, nil
}}
}),
searchMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch(nil)).Return(exampleUserSearchResult, nil).Times(1)
conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize).
Return(exampleGroupSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
bindEndUserMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Times(1)
},
wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) {
r.User = &user.DefaultInfo{
Name: testUserSearchResultUsernameAttributeValue,
UID: base64.RawURLEncoding.EncodeToString([]byte(testUserSearchResultUIDAttributeValue)),
Groups: []string{"something-else-some-upstream-group-dn1", "something-else-some-upstream-group-dn2"},
}
}),
},
{ {
name: "when the group search base is empty then skip the group search entirely", name: "when the group search base is empty then skip the group search entirely",
username: testUpstreamUsername, username: testUpstreamUsername,
@ -254,7 +281,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, },
wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) {
info := r.User.(*user.DefaultInfo) info := r.User.(*user.DefaultInfo)
info.Groups = nil info.Groups = []string{}
}), }),
}, },
{ {
@ -958,6 +985,24 @@ func TestEndUserAuthentication(t *testing.T) {
}, },
wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{
name: "when the group search has an override func that errors",
username: testUpstreamUsername,
password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) {
p.GroupAttributeParsingOverrides = map[string]func(*ldap.Entry) (string, error){testGroupSearchGroupNameAttribute: func(entry *ldap.Entry) (string, error) {
return "", errors.New("some error")
}}
}),
searchMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch(nil)).Return(exampleUserSearchResult, nil).Times(1)
conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize).
Return(exampleGroupSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantError: fmt.Sprintf("error finding groups for user %s: some error", testUserSearchResultDNValue),
},
{ {
name: "when binding as the found user returns an error", name: "when binding as the found user returns an error",
username: testUpstreamUsername, username: testUpstreamUsername,
@ -1100,6 +1145,18 @@ func TestUpstreamRefresh(t *testing.T) {
Controls: nil, // don't need paging because we set the SizeLimit so small Controls: nil, // don't need paging because we set the SizeLimit so small
} }
expectedGroupSearch := &ldap.SearchRequest{
BaseDN: testGroupSearchBase,
Scope: ldap.ScopeWholeSubtree,
DerefAliases: ldap.NeverDerefAliases,
SizeLimit: 0, // unlimited size because we will search with paging
TimeLimit: 90,
TypesOnly: false,
Filter: testGroupSearchFilterInterpolated,
Attributes: []string{testGroupSearchGroupNameAttribute},
Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us
}
happyPathUserSearchResult := &ldap.SearchResult{ happyPathUserSearchResult := &ldap.SearchResult{
Entries: []*ldap.Entry{ Entries: []*ldap.Entry{
{ {
@ -1146,6 +1203,7 @@ func TestUpstreamRefresh(t *testing.T) {
setupMocks func(conn *mockldapconn.MockConn) setupMocks func(conn *mockldapconn.MockConn)
dialError error dialError error
wantErr string wantErr string
wantGroups []string
}{ }{
{ {
name: "happy path where searching the dn returns a single entry", name: "happy path where searching the dn returns a single entry",
@ -1155,6 +1213,90 @@ func TestUpstreamRefresh(t *testing.T) {
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantGroups: []string{},
},
{
name: "happy path where group search returns groups",
providerConfig: &ProviderConfig{
Name: "some-provider-name",
Host: testHost,
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
ConnectionProtocol: TLS,
BindUsername: testBindUsername,
BindPassword: testBindPassword,
UserSearch: UserSearchConfig{
Base: testUserSearchBase,
UIDAttribute: testUserSearchUIDAttribute,
UsernameAttribute: testUserSearchUsernameAttribute,
},
GroupSearch: GroupSearchConfig{
Base: testGroupSearchBase,
Filter: testGroupSearchFilter,
GroupNameAttribute: testGroupSearchGroupNameAttribute,
},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
},
},
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{
{
DN: testGroupSearchResultDNValue1,
Attributes: []*ldap.EntryAttribute{
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}),
},
},
{
DN: testGroupSearchResultDNValue2,
Attributes: []*ldap.EntryAttribute{
ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}),
},
},
},
Referrals: []string{}, // note that we are not following referrals at this time
Controls: []ldap.Control{},
}, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2},
},
{
name: "happy path where group search returns no groups",
providerConfig: &ProviderConfig{
Name: "some-provider-name",
Host: testHost,
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
ConnectionProtocol: TLS,
BindUsername: testBindUsername,
BindPassword: testBindPassword,
UserSearch: UserSearchConfig{
Base: testUserSearchBase,
UIDAttribute: testUserSearchUIDAttribute,
UsernameAttribute: testUserSearchUsernameAttribute,
},
GroupSearch: GroupSearchConfig{
Base: testGroupSearchBase,
Filter: testGroupSearchFilter,
GroupNameAttribute: testGroupSearchGroupNameAttribute,
},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
},
},
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{},
Referrals: []string{}, // note that we are not following referrals at this time
Controls: []ldap.Control{},
}, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantGroups: []string{},
}, },
{ {
name: "error where dial fails", name: "error where dial fails",
@ -1421,6 +1563,37 @@ func TestUpstreamRefresh(t *testing.T) {
}, },
wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login", wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login",
}, },
{
name: "group search returns an error",
providerConfig: &ProviderConfig{
Name: "some-provider-name",
Host: testHost,
CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test
ConnectionProtocol: TLS,
BindUsername: testBindUsername,
BindPassword: testBindPassword,
UserSearch: UserSearchConfig{
Base: testUserSearchBase,
UIDAttribute: testUserSearchUIDAttribute,
UsernameAttribute: testUserSearchUsernameAttribute,
},
GroupSearch: GroupSearchConfig{
Base: testGroupSearchBase,
Filter: testGroupSearchFilter,
GroupNameAttribute: testGroupSearchGroupNameAttribute,
},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{
pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute),
},
},
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1)
conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1)
conn.EXPECT().Close().Times(1)
},
wantErr: "error searching for group memberships for user with DN \"some-upstream-user-dn\": some search error",
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -1435,9 +1608,9 @@ func TestUpstreamRefresh(t *testing.T) {
} }
dialWasAttempted := false dialWasAttempted := false
providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) { tt.providerConfig.Dialer = LDAPDialerFunc(func(ctx context.Context, addr endpointaddr.HostPort) (Conn, error) {
dialWasAttempted = true dialWasAttempted = true
require.Equal(t, providerConfig.Host, addr.Endpoint()) require.Equal(t, tt.providerConfig.Host, addr.Endpoint())
if tt.dialError != nil { if tt.dialError != nil {
return nil, tt.dialError return nil, tt.dialError
} }
@ -1446,9 +1619,9 @@ func TestUpstreamRefresh(t *testing.T) {
}) })
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000")) initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
ldapProvider := New(*providerConfig) ldapProvider := New(*tt.providerConfig)
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{ groups, err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{
Username: testUserSearchResultUsernameAttributeValue, Username: testUserSearchResultUsernameAttributeValue,
Subject: subject, Subject: subject,
DN: testUserSearchResultDNValue, DN: testUserSearchResultDNValue,
@ -1461,6 +1634,7 @@ func TestUpstreamRefresh(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
require.Equal(t, true, dialWasAttempted) require.Equal(t, true, dialWasAttempted)
require.Equal(t, tt.wantGroups, groups)
}) })
} }
} }

View File

@ -244,7 +244,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.Base = "" p.GroupSearch.Base = ""
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },
@ -257,7 +257,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree p.GroupSearch.Base = "ou=users,dc=pinniped,dc=dev" // there are no groups under this part of the tree
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },
@ -302,7 +302,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test p.GroupSearch.GroupNameAttribute = "objectClass" // silly example, but still a meaningful test
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames", "groupOfNames"}}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{"groupOfNames"}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },
@ -328,7 +328,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema p.GroupSearch.Filter = "foobar={}" // foobar is not a valid attribute name for this LDAP server's schema
})), })),
wantAuthResponse: &authenticators.Response{ wantAuthResponse: &authenticators.Response{
User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: nil}, User: &user.DefaultInfo{Name: "pinny", UID: b64("1000"), Groups: []string{}},
DN: "cn=pinny,ou=users,dc=pinniped,dc=dev", DN: "cn=pinny,ou=users,dc=pinniped,dc=dev",
ExtraRefreshAttributes: map[string]string{}, ExtraRefreshAttributes: map[string]string{},
}, },

View File

@ -60,12 +60,16 @@ func TestSupervisorLogin_Browser(t *testing.T) {
wantDownstreamIDTokenSubjectToMatch string wantDownstreamIDTokenSubjectToMatch string
wantDownstreamIDTokenUsernameToMatch func(username string) string wantDownstreamIDTokenUsernameToMatch func(username string) string
wantDownstreamIDTokenGroups []string wantDownstreamIDTokenGroups []string
wantDownstreamIDTokenGroupsAfterRefresh []string
wantErrorDescription string wantErrorDescription string
wantErrorType string wantErrorType string
// Either revoke the user's session on the upstream provider, or manipulate the user's session // Either revoke the user's session on the upstream provider, or manipulate the user's session
// data in such a way that it should cause the next upstream refresh attempt to fail. // data in such a way that it should cause the next upstream refresh attempt to fail.
breakRefreshSessionData func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string) breakRefreshSessionData func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string)
// Edit the refresh session data between the initial login and the refresh, which is expected to
// succeed.
editRefreshSessionDataWithoutBreaking func(t *testing.T, sessionData *psession.PinnipedSession, idpName, username string)
}{ }{
{ {
name: "oidc with default username and groups claim settings", name: "oidc with default username and groups claim settings",
@ -128,6 +132,14 @@ func TestSupervisorLogin_Browser(t *testing.T) {
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Issuer+"?sub=") + ".+", wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Issuer+"?sub=") + ".+",
wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" }, wantDownstreamIDTokenUsernameToMatch: func(_ string) string { return "^" + regexp.QuoteMeta(env.SupervisorUpstreamOIDC.Username) + "$" },
wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups, wantDownstreamIDTokenGroups: env.SupervisorUpstreamOIDC.ExpectedGroups,
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) {
// even if we update this group to the wrong thing, we expect that it will return to the correct
// value after we refresh.
// However if there are no expected groups then they will not update, so we should skip this.
if len(env.SupervisorUpstreamOIDC.ExpectedGroups) > 0 {
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
}
},
}, },
{ {
name: "oidc without refresh token", name: "oidc without refresh token",
@ -272,6 +284,11 @@ func TestSupervisorLogin_Browser(t *testing.T) {
false, false,
) )
}, },
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) {
// even if we update this group to the wrong thing, we expect that it will return to the correct
// value after we refresh.
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
},
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) { breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
customSessionData := pinnipedSession.Custom customSessionData := pinnipedSession.Custom
require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType) require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType)
@ -291,6 +308,94 @@ func TestSupervisorLogin_Browser(t *testing.T) {
}, },
wantDownstreamIDTokenGroups: env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs, wantDownstreamIDTokenGroups: env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs,
}, },
{
name: "ldap with email as username and group search base that doesn't return anything, and using an LDAP provider which supports TLS",
maybeSkip: func(t *testing.T) {
t.Helper()
if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) {
t.Skip("LDAP integration test requires connectivity to an LDAP server")
}
if env.SupervisorUpstreamLDAP.UserSearchBase == env.SupervisorUpstreamLDAP.GroupSearchBase {
// This test relies on using the user search base as the group search base, to simulate
// searching for groups and not finding any.
// If the users and groups are stored in the same place, then we will get groups
// back, so this test wouldn't make sense.
t.Skip("must have a different user search base than group search base")
}
},
createIDP: func(t *testing.T) string {
t.Helper()
secret := testlib.CreateTestSecret(t, env.SupervisorNamespace, "ldap-service-account", v1.SecretTypeBasicAuth,
map[string]string{
v1.BasicAuthUsernameKey: env.SupervisorUpstreamLDAP.BindUsername,
v1.BasicAuthPasswordKey: env.SupervisorUpstreamLDAP.BindPassword,
},
)
ldapIDP := testlib.CreateTestLDAPIdentityProvider(t, idpv1alpha1.LDAPIdentityProviderSpec{
Host: env.SupervisorUpstreamLDAP.Host,
TLS: &idpv1alpha1.TLSSpec{
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamLDAP.CABundle)),
},
Bind: idpv1alpha1.LDAPIdentityProviderBind{
SecretName: secret.Name,
},
UserSearch: idpv1alpha1.LDAPIdentityProviderUserSearch{
Base: env.SupervisorUpstreamLDAP.UserSearchBase,
Filter: "",
Attributes: idpv1alpha1.LDAPIdentityProviderUserSearchAttributes{
Username: env.SupervisorUpstreamLDAP.TestUserMailAttributeName,
UID: env.SupervisorUpstreamLDAP.TestUserUniqueIDAttributeName,
},
},
GroupSearch: idpv1alpha1.LDAPIdentityProviderGroupSearch{
Base: env.SupervisorUpstreamLDAP.UserSearchBase, // groups not stored at the user search base
Filter: "",
Attributes: idpv1alpha1.LDAPIdentityProviderGroupSearchAttributes{
GroupName: "dn",
},
},
}, idpv1alpha1.LDAPPhaseReady)
expectedMsg := fmt.Sprintf(
`successfully able to connect to "%s" and bind as user "%s" [validated with Secret "%s" at version "%s"]`,
env.SupervisorUpstreamLDAP.Host, env.SupervisorUpstreamLDAP.BindUsername,
secret.Name, secret.ResourceVersion,
)
requireSuccessfulLDAPIdentityProviderConditions(t, ldapIDP, expectedMsg)
return ldapIDP.Name
},
requestAuthorization: func(t *testing.T, downstreamAuthorizeURL, _, _, _ string, httpClient *http.Client) {
requestAuthorizationUsingCLIPasswordFlow(t,
downstreamAuthorizeURL,
env.SupervisorUpstreamLDAP.TestUserMailAttributeValue, // username to present to server during login
env.SupervisorUpstreamLDAP.TestUserPassword, // password to present to server during login
httpClient,
false,
)
},
editRefreshSessionDataWithoutBreaking: func(t *testing.T, sessionData *psession.PinnipedSession, _, _ string) {
// even if we update this group to the wrong thing, we expect that it will return to the correct
// value after we refresh.
sessionData.Fosite.Claims.Extra["groups"] = []string{"some-wrong-group", "some-other-group"}
},
breakRefreshSessionData: func(t *testing.T, pinnipedSession *psession.PinnipedSession, _, _ string) {
customSessionData := pinnipedSession.Custom
require.Equal(t, psession.ProviderTypeLDAP, customSessionData.ProviderType)
require.NotEmpty(t, customSessionData.LDAP.UserDN)
fositeSessionData := pinnipedSession.Fosite
fositeSessionData.Claims.Subject = "not-right"
},
// the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(
"ldaps://"+env.SupervisorUpstreamLDAP.Host+
"?base="+url.QueryEscape(env.SupervisorUpstreamLDAP.UserSearchBase)+
"&sub="+base64.RawURLEncoding.EncodeToString([]byte(env.SupervisorUpstreamLDAP.TestUserUniqueIDAttributeValue)),
) + "$",
// the ID token Username should have been pulled from the requested UserSearch.Attributes.Username attribute
wantDownstreamIDTokenUsernameToMatch: func(_ string) string {
return "^" + regexp.QuoteMeta(env.SupervisorUpstreamLDAP.TestUserMailAttributeValue) + "$"
},
wantDownstreamIDTokenGroups: []string{},
},
{ {
name: "ldap with CN as username and group names as CNs and using an LDAP provider which only supports StartTLS", // try another variation of configuration options name: "ldap with CN as username and group names as CNs and using an LDAP provider which only supports StartTLS", // try another variation of configuration options
maybeSkip: func(t *testing.T) { maybeSkip: func(t *testing.T) {
@ -1241,7 +1346,6 @@ func TestSupervisorLogin_Browser(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// Create the LDAPIdentityProvider using GenerateName to get a random name.
upstreams := client.IDPV1alpha1().LDAPIdentityProviders(env.SupervisorNamespace) upstreams := client.IDPV1alpha1().LDAPIdentityProviders(env.SupervisorNamespace)
ldapIDP, err := upstreams.Get(ctx, idpName, metav1.GetOptions{}) ldapIDP, err := upstreams.Get(ctx, idpName, metav1.GetOptions{})
require.NoError(t, err) require.NoError(t, err)
@ -1263,6 +1367,87 @@ func TestSupervisorLogin_Browser(t *testing.T) {
}, },
wantDownstreamIDTokenGroups: env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs, wantDownstreamIDTokenGroups: env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs,
}, },
{
name: "ldap refresh updates groups to be empty after deleting the group search base",
maybeSkip: func(t *testing.T) {
t.Helper()
if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) {
t.Skip("LDAP integration test requires connectivity to an LDAP server")
}
},
createIDP: func(t *testing.T) string {
t.Helper()
secret := testlib.CreateTestSecret(t, env.SupervisorNamespace, "ldap-service-account", v1.SecretTypeBasicAuth,
map[string]string{
v1.BasicAuthUsernameKey: env.SupervisorUpstreamLDAP.BindUsername,
v1.BasicAuthPasswordKey: env.SupervisorUpstreamLDAP.BindPassword,
},
)
secretName := secret.Name
ldapIDP := testlib.CreateTestLDAPIdentityProvider(t, idpv1alpha1.LDAPIdentityProviderSpec{
Host: env.SupervisorUpstreamLDAP.Host,
TLS: &idpv1alpha1.TLSSpec{
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamLDAP.CABundle)),
},
Bind: idpv1alpha1.LDAPIdentityProviderBind{
SecretName: secretName,
},
UserSearch: idpv1alpha1.LDAPIdentityProviderUserSearch{
Base: env.SupervisorUpstreamLDAP.UserSearchBase,
Filter: "",
Attributes: idpv1alpha1.LDAPIdentityProviderUserSearchAttributes{
Username: env.SupervisorUpstreamLDAP.TestUserMailAttributeName,
UID: env.SupervisorUpstreamLDAP.TestUserUniqueIDAttributeName,
},
},
GroupSearch: idpv1alpha1.LDAPIdentityProviderGroupSearch{
Base: env.SupervisorUpstreamLDAP.GroupSearchBase,
Filter: "",
Attributes: idpv1alpha1.LDAPIdentityProviderGroupSearchAttributes{
GroupName: "dn",
},
},
}, idpv1alpha1.LDAPPhaseReady)
return ldapIDP.Name
},
requestAuthorization: func(t *testing.T, downstreamAuthorizeURL, _, _, _ string, httpClient *http.Client) {
requestAuthorizationUsingCLIPasswordFlow(t,
downstreamAuthorizeURL,
env.SupervisorUpstreamLDAP.TestUserMailAttributeValue, // username to present to server during login
env.SupervisorUpstreamLDAP.TestUserPassword, // password to present to server during login
httpClient,
false,
)
},
editRefreshSessionDataWithoutBreaking: func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, _ string) {
// get the idp, update the config.
client := testlib.NewSupervisorClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
upstreams := client.IDPV1alpha1().LDAPIdentityProviders(env.SupervisorNamespace)
ldapIDP, err := upstreams.Get(ctx, idpName, metav1.GetOptions{})
require.NoError(t, err)
ldapIDP.Spec.GroupSearch.Base = ""
_, err = upstreams.Update(ctx, ldapIDP, metav1.UpdateOptions{})
require.NoError(t, err)
time.Sleep(10 * time.Second) // wait for controllers to pick up the change
},
// the ID token Subject should be the Host URL plus the value pulled from the requested UserSearch.Attributes.UID attribute
wantDownstreamIDTokenSubjectToMatch: "^" + regexp.QuoteMeta(
"ldaps://"+env.SupervisorUpstreamLDAP.Host+
"?base="+url.QueryEscape(env.SupervisorUpstreamLDAP.UserSearchBase)+
"&sub="+base64.RawURLEncoding.EncodeToString([]byte(env.SupervisorUpstreamLDAP.TestUserUniqueIDAttributeValue)),
) + "$",
// the ID token Username should have been pulled from the requested UserSearch.Attributes.Username attribute
wantDownstreamIDTokenUsernameToMatch: func(_ string) string {
return "^" + regexp.QuoteMeta(env.SupervisorUpstreamLDAP.TestUserMailAttributeValue) + "$"
},
wantDownstreamIDTokenGroups: env.SupervisorUpstreamLDAP.TestUserDirectGroupsDNs,
wantDownstreamIDTokenGroupsAfterRefresh: []string{},
},
} }
for _, test := range tests { for _, test := range tests {
tt := test tt := test
@ -1272,12 +1457,14 @@ func TestSupervisorLogin_Browser(t *testing.T) {
testSupervisorLogin(t, testSupervisorLogin(t,
tt.createIDP, tt.createIDP,
tt.requestAuthorization, tt.requestAuthorization,
tt.editRefreshSessionDataWithoutBreaking,
tt.breakRefreshSessionData, tt.breakRefreshSessionData,
tt.createTestUser, tt.createTestUser,
tt.deleteTestUser, tt.deleteTestUser,
tt.wantDownstreamIDTokenSubjectToMatch, tt.wantDownstreamIDTokenSubjectToMatch,
tt.wantDownstreamIDTokenUsernameToMatch, tt.wantDownstreamIDTokenUsernameToMatch,
tt.wantDownstreamIDTokenGroups, tt.wantDownstreamIDTokenGroups,
tt.wantDownstreamIDTokenGroupsAfterRefresh,
tt.wantErrorDescription, tt.wantErrorDescription,
tt.wantErrorType, tt.wantErrorType,
) )
@ -1405,13 +1592,15 @@ func requireEventuallySuccessfulActiveDirectoryIdentityProviderConditions(t *tes
func testSupervisorLogin( func testSupervisorLogin(
t *testing.T, t *testing.T,
createIDP func(t *testing.T) string, createIDP func(t *testing.T) string,
requestAuthorization func(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, username, password string, httpClient *http.Client), requestAuthorization func(t *testing.T, downstreamAuthorizeURL string, downstreamCallbackURL string, username string, password string, httpClient *http.Client),
editRefreshSessionDataWithoutBreaking func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, username string),
breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, username string), breakRefreshSessionData func(t *testing.T, pinnipedSession *psession.PinnipedSession, idpName, username string),
createTestUser func(t *testing.T) (string, string), createTestUser func(t *testing.T) (string, string),
deleteTestUser func(t *testing.T, username string), deleteTestUser func(t *testing.T, username string),
wantDownstreamIDTokenSubjectToMatch string, wantDownstreamIDTokenSubjectToMatch string,
wantDownstreamIDTokenUsernameToMatch func(username string) string, wantDownstreamIDTokenUsernameToMatch func(username string) string,
wantDownstreamIDTokenGroups []string, wantDownstreamIDTokenGroups []string,
wantDownstreamIDTokenGroupsAfterRefresh []string,
wantErrorDescription string, wantErrorDescription string,
wantErrorType string, wantErrorType string,
) { ) {
@ -1565,6 +1754,28 @@ func testSupervisorLogin(
// token exchange on the original token // token exchange on the original token
doTokenExchange(t, &downstreamOAuth2Config, tokenResponse, httpClient, discovery) doTokenExchange(t, &downstreamOAuth2Config, tokenResponse, httpClient, discovery)
if editRefreshSessionDataWithoutBreaking != nil {
latestRefreshToken := tokenResponse.RefreshToken
signatureOfLatestRefreshToken := getFositeDataSignature(t, latestRefreshToken)
// First use the latest downstream refresh token to look up the corresponding session in the Supervisor's storage.
kubeClient := testlib.NewKubernetesClientset(t)
supervisorSecretsClient := kubeClient.CoreV1().Secrets(env.SupervisorNamespace)
oauthStore := oidc.NewKubeStorage(supervisorSecretsClient, oidc.DefaultOIDCTimeoutsConfiguration())
storedRefreshSession, err := oauthStore.GetRefreshTokenSession(ctx, signatureOfLatestRefreshToken, nil)
require.NoError(t, err)
// Next mutate the part of the session that is used during upstream refresh.
pinnipedSession, ok := storedRefreshSession.GetSession().(*psession.PinnipedSession)
require.True(t, ok, "should have been able to cast session data to PinnipedSession")
editRefreshSessionDataWithoutBreaking(t, pinnipedSession, idpName, username)
// Then save the mutated Secret back to Kubernetes.
// There is no update function, so delete and create again at the same name.
require.NoError(t, oauthStore.DeleteRefreshTokenSession(ctx, signatureOfLatestRefreshToken))
require.NoError(t, oauthStore.CreateRefreshTokenSession(ctx, signatureOfLatestRefreshToken, storedRefreshSession))
}
// Use the refresh token to get new tokens // Use the refresh token to get new tokens
refreshSource := downstreamOAuth2Config.TokenSource(oidcHTTPClientContext, &oauth2.Token{RefreshToken: tokenResponse.RefreshToken}) refreshSource := downstreamOAuth2Config.TokenSource(oidcHTTPClientContext, &oauth2.Token{RefreshToken: tokenResponse.RefreshToken})
refreshedTokenResponse, err := refreshSource.Token() refreshedTokenResponse, err := refreshSource.Token()
@ -1572,9 +1783,12 @@ func testSupervisorLogin(
// When refreshing, expect to get an "at_hash" claim, but no "nonce" claim. // When refreshing, expect to get an "at_hash" claim, but no "nonce" claim.
expectRefreshedIDTokenClaims := []string{"iss", "exp", "sub", "aud", "auth_time", "iat", "jti", "rat", "username", "groups", "at_hash"} expectRefreshedIDTokenClaims := []string{"iss", "exp", "sub", "aud", "auth_time", "iat", "jti", "rat", "username", "groups", "at_hash"}
if wantDownstreamIDTokenGroupsAfterRefresh == nil {
wantDownstreamIDTokenGroupsAfterRefresh = wantDownstreamIDTokenGroups
}
verifyTokenResponse(t, verifyTokenResponse(t,
refreshedTokenResponse, discovery, downstreamOAuth2Config, "", refreshedTokenResponse, discovery, downstreamOAuth2Config, "",
expectRefreshedIDTokenClaims, wantDownstreamIDTokenSubjectToMatch, wantDownstreamIDTokenUsernameToMatch(username), wantDownstreamIDTokenGroups) expectRefreshedIDTokenClaims, wantDownstreamIDTokenSubjectToMatch, wantDownstreamIDTokenUsernameToMatch(username), wantDownstreamIDTokenGroupsAfterRefresh)
require.NotEqual(t, tokenResponse.AccessToken, refreshedTokenResponse.AccessToken) require.NotEqual(t, tokenResponse.AccessToken, refreshedTokenResponse.AccessToken)
require.NotEqual(t, tokenResponse.RefreshToken, refreshedTokenResponse.RefreshToken) require.NotEqual(t, tokenResponse.RefreshToken, refreshedTokenResponse.RefreshToken)