Refactor out a function, add tests for getting the wrong idp uid

This commit is contained in:
Margo Crawford 2021-10-26 17:03:16 -07:00
parent 722b5dcc1b
commit 84edfcb541
3 changed files with 80 additions and 21 deletions

View File

@ -75,15 +75,10 @@ func NewHandler(
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
extra := session.Fosite.Claims.Extra downstreamUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if extra == nil { if err != nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return err
} }
downstreamUsernameInterface := extra["username"]
if downstreamUsernameInterface == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
downstreamUsername := downstreamUsernameInterface.(string)
downstreamSubject := session.Fosite.Claims.Subject downstreamSubject := session.Fosite.Claims.Subject
customSessionData := session.Custom customSessionData := session.Custom
@ -225,3 +220,16 @@ func findLDAPProviderByNameAndValidateUID(
return nil, "", errorsx.WithStack(errUpstreamRefreshError. return nil, "", errorsx.WithStack(errUpstreamRefreshError.
WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType))
} }
func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession) (string, error) {
extra := session.Fosite.Claims.Extra
if extra == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
downstreamUsernameInterface := extra["username"]
if downstreamUsernameInterface == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
downstreamUsername := downstreamUsernameInterface.(string)
return downstreamUsername, nil
}

View File

@ -1647,7 +1647,7 @@ func TestRefreshGrant(t *testing.T) {
ProviderUID: activeDirectoryUpstreamResourceUID, ProviderUID: activeDirectoryUpstreamResourceUID,
ProviderName: activeDirectoryUpstreamName, ProviderName: activeDirectoryUpstreamName,
ProviderType: activeDirectoryUpstreamType, ProviderType: activeDirectoryUpstreamType,
LDAP: nil, ActiveDirectory: nil,
}, },
), ),
}, },
@ -1705,9 +1705,9 @@ func TestRefreshGrant(t *testing.T) {
}, },
{ {
name: "upstream active directory refresh when the active directory session data does not contain dn", name: "upstream active directory refresh when the active directory session data does not contain dn",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ idps: oidctestutil.NewUpstreamIDPListerBuilder().WithActiveDirectory(&oidctestutil.TestUpstreamLDAPIdentityProvider{
Name: ldapUpstreamName, Name: activeDirectoryUpstreamName,
ResourceUID: ldapUpstreamResourceUID, ResourceUID: activeDirectoryUpstreamResourceUID,
URL: ldapUpstreamURL, URL: ldapUpstreamURL,
}), }),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
@ -1716,7 +1716,7 @@ func TestRefreshGrant(t *testing.T) {
ProviderUID: ldapUpstreamResourceUID, ProviderUID: ldapUpstreamResourceUID,
ProviderName: ldapUpstreamName, ProviderName: ldapUpstreamName,
ProviderType: ldapUpstreamType, ProviderType: ldapUpstreamType,
LDAP: &psession.LDAPSessionData{ ActiveDirectory: &psession.ActiveDirectorySessionData{
UserDN: "", UserDN: "",
}, },
}, },
@ -1725,7 +1725,7 @@ func TestRefreshGrant(t *testing.T) {
ProviderUID: ldapUpstreamResourceUID, ProviderUID: ldapUpstreamResourceUID,
ProviderName: ldapUpstreamName, ProviderName: ldapUpstreamName,
ProviderType: ldapUpstreamType, ProviderType: ldapUpstreamType,
LDAP: &psession.LDAPSessionData{ ActiveDirectory: &psession.ActiveDirectorySessionData{
UserDN: "", UserDN: "",
}, },
}, },
@ -1922,6 +1922,58 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
}, },
{
name: "when the ldap provider in the session storage is found but has the wrong resource UID during the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
Name: ldapUpstreamName,
ResourceUID: "the-wrong-uid",
URL: ldapUpstreamURL,
}),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: happyLDAPCustomSessionData,
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
happyLDAPCustomSessionData,
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Provider 'some-ldap-idp' of type 'ldap' from upstream session data has changed its resource UID since authentication."
}
`),
},
},
},
{
name: "when the active directory provider in the session storage is found but has the wrong resource UID during the refresh request",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithActiveDirectory(&oidctestutil.TestUpstreamLDAPIdentityProvider{
Name: activeDirectoryUpstreamName,
ResourceUID: "the-wrong-uid",
URL: ldapUpstreamURL,
}),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: happyActiveDirectoryCustomSessionData,
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
happyActiveDirectoryCustomSessionData,
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Provider 'some-ad-idp' of type 'activedirectory' from upstream session data has changed its resource UID since authentication."
}
`),
},
},
},
} }
for _, test := range tests { for _, test := range tests {
test := test test := test

View File

@ -210,7 +210,7 @@ func (p *Provider) PerformRefresh(ctx context.Context, userDN string, expectedUs
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN) newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
if err != nil { if err != nil {
return err // TODO test having no values or more than one maybe return err
} }
if newUsername != expectedUsername { if newUsername != expectedUsername {
return fmt.Errorf(`searching for user "%s" returned a different username than the previous value. expected: "%s", actual: "%s"`, return fmt.Errorf(`searching for user "%s" returned a different username than the previous value. expected: "%s", actual: "%s"`,
@ -220,15 +220,14 @@ func (p *Provider) PerformRefresh(ctx context.Context, userDN string, expectedUs
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN) newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
if err != nil { if err != nil {
return err // TODO test return err
} }
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL()) newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
if newSubject != expectedSubject { if newSubject != expectedSubject {
return fmt.Errorf(`searching for user "%s" produced a different subject than the previous value. expected: "%s", actual: "%s"`, userDN, expectedSubject, newSubject) return fmt.Errorf(`searching for user "%s" produced a different subject than the previous value. expected: "%s", actual: "%s"`, userDN, expectedSubject, newSubject)
} }
// do nothing. if we got exactly one search result back then that means the user // we checked that the user still exists and their information is the same, so just return.
// still exists.
return nil return nil
} }