Add a timeout to the upstream OIDC refresh calls
Unfortunately this means that we lose test coverage on passing the request's context through to the mocked PerformRefresh function. We used to be able to assert equality on the request's context because it was passed through unchanged.
This commit is contained in:
parent
3301a62053
commit
ba83c12f93
@ -240,7 +240,10 @@ func performUpstreamOIDCRefreshWithRetriesOnError(
|
||||
|
||||
performRefreshOnce := func() error {
|
||||
var err error
|
||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
// Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -950,7 +950,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
}
|
||||
@ -960,7 +959,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: ldapUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil,
|
||||
DN: ldapUpstreamDN,
|
||||
ExpectedSubject: goodSubject,
|
||||
ExpectedUsername: goodUsername,
|
||||
@ -972,7 +970,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: activeDirectoryUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil,
|
||||
DN: activeDirectoryUpstreamDN,
|
||||
ExpectedSubject: goodSubject,
|
||||
ExpectedUsername: goodUsername,
|
||||
@ -984,7 +981,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamValidateTokens{
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
Tok: expectedTokens,
|
||||
ExpectedIDTokenNonce: "", // always expect empty string
|
||||
RequireUserInfo: false,
|
||||
@ -1149,7 +1145,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
|
||||
oidcUpstreamName,
|
||||
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
|
||||
ExpectedIDTokenNonce: "", // always expect empty string
|
||||
RequireIDToken: false,
|
||||
@ -1748,7 +1743,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
@ -1777,7 +1771,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
@ -2713,9 +2706,8 @@ func TestRefreshGrant(t *testing.T) {
|
||||
if test.modifyRefreshTokenStorage != nil {
|
||||
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
|
||||
}
|
||||
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
|
||||
req := httptest.NewRequest("POST", "/path/shouldn't/matter",
|
||||
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext)
|
||||
happyRefreshRequestBody(firstRefreshToken).ReadCloser())
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if test.refreshRequest.modifyTokenRequest != nil {
|
||||
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
|
||||
@ -2728,7 +2720,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
|
||||
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
|
||||
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
|
||||
test.idps.RequireExactlyNCallsToPerformRefresh(t,
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
|
||||
@ -2741,7 +2732,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the
|
||||
// new ID token that was returned by the upstream refresh.
|
||||
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
|
||||
test.idps.RequireExactlyOneCallToValidateToken(t,
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
|
||||
|
@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
|
||||
// PerformRefreshArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
|
||||
type PerformRefreshArgs struct {
|
||||
Ctx context.Context
|
||||
RefreshToken string
|
||||
DN string
|
||||
ExpectedUsername string
|
||||
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
|
||||
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
|
||||
type ValidateTokenAndMergeWithUserInfoArgs struct {
|
||||
Ctx context.Context
|
||||
Tok *oauth2.Token
|
||||
ExpectedIDTokenNonce nonce.Nonce
|
||||
RequireIDToken bool
|
||||
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
||||
return u.URL
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
u.performRefreshCallCount++
|
||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||
Ctx: ctx,
|
||||
DN: storedRefreshAttributes.DN,
|
||||
ExpectedUsername: storedRefreshAttributes.Username,
|
||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
|
||||
}
|
||||
u.performRefreshCallCount++
|
||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||
Ctx: ctx,
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
return u.PerformRefreshFunc(ctx, refreshToken)
|
||||
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
|
||||
}
|
||||
u.validateTokenAndMergeWithUserInfoCallCount++
|
||||
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: ctx,
|
||||
Tok: tok,
|
||||
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||
RequireIDToken: requireIDToken,
|
||||
|
Loading…
Reference in New Issue
Block a user