From ba83c12f939bddd2ea736d1ed51bb5ec143d89da Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 20 Jan 2022 13:11:05 -0800 Subject: [PATCH] Add a timeout to the upstream OIDC refresh calls Unfortunately this means that we lose test coverage on passing the request's context through to the mocked PerformRefresh function. We used to be able to assert equality on the request's context because it was passed through unchanged. --- internal/oidc/token/token_handler.go | 5 ++++- internal/oidc/token/token_handler_test.go | 12 +----------- internal/testutil/oidctestutil/oidctestutil.go | 7 +------ 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 40a8061d..2b2d7a7d 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -240,7 +240,10 @@ func performUpstreamOIDCRefreshWithRetriesOnError( performRefreshOnce := func() error { var err error - tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) + // Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects. + timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second) + defer cancel() + tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken) return err } diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index d771bcab..901e097e 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -950,7 +950,6 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamRefresh{ performedByUpstreamName: oidcUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, // this will be filled in with the actual request context by the test below RefreshToken: oidcUpstreamInitialRefreshToken, }, } @@ -960,7 +959,6 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamRefresh{ performedByUpstreamName: ldapUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, DN: ldapUpstreamDN, ExpectedSubject: goodSubject, ExpectedUsername: goodUsername, @@ -972,7 +970,6 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamRefresh{ performedByUpstreamName: activeDirectoryUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, DN: activeDirectoryUpstreamDN, ExpectedSubject: goodSubject, ExpectedUsername: goodUsername, @@ -984,7 +981,6 @@ func TestRefreshGrant(t *testing.T) { return &expectedUpstreamValidateTokens{ performedByUpstreamName: oidcUpstreamName, args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ - Ctx: nil, // this will be filled in with the actual request context by the test below Tok: expectedTokens, ExpectedIDTokenNonce: "", // always expect empty string RequireUserInfo: false, @@ -1149,7 +1145,6 @@ func TestRefreshGrant(t *testing.T) { wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{ oidcUpstreamName, &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ - Ctx: nil, // this will be filled in with the actual request context by the test below Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token ExpectedIDTokenNonce: "", // always expect empty string RequireIDToken: false, @@ -1748,7 +1743,6 @@ func TestRefreshGrant(t *testing.T) { numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries performedByUpstreamName: oidcUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, // this will be filled in with the actual request context by the test below RefreshToken: oidcUpstreamInitialRefreshToken, }, }, @@ -1777,7 +1771,6 @@ func TestRefreshGrant(t *testing.T) { numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries performedByUpstreamName: oidcUpstreamName, args: &oidctestutil.PerformRefreshArgs{ - Ctx: nil, // this will be filled in with the actual request context by the test below RefreshToken: oidcUpstreamInitialRefreshToken, }, }, @@ -2713,9 +2706,8 @@ func TestRefreshGrant(t *testing.T) { if test.modifyRefreshTokenStorage != nil { test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken) } - reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") req := httptest.NewRequest("POST", "/path/shouldn't/matter", - happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext) + happyRefreshRequestBody(firstRefreshToken).ReadCloser()) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") if test.refreshRequest.modifyTokenRequest != nil { test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string)) @@ -2728,7 +2720,6 @@ func TestRefreshGrant(t *testing.T) { // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. if test.refreshRequest.want.wantUpstreamRefreshCall != nil { - test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext test.idps.RequireExactlyNCallsToPerformRefresh(t, test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName, @@ -2741,7 +2732,6 @@ func TestRefreshGrant(t *testing.T) { // Test that we did or did not make a call to the upstream OIDC provider interface to validate the // new ID token that was returned by the upstream refresh. if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil { - test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext test.idps.RequireExactlyOneCallToValidateToken(t, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args, diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 3f1a5752..3bf0cd98 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct { // PerformRefreshArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). type PerformRefreshArgs struct { - Ctx context.Context RefreshToken string DN string ExpectedUsername string @@ -79,7 +78,6 @@ type RevokeTokenArgs struct { // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc(). type ValidateTokenAndMergeWithUserInfoArgs struct { - Ctx context.Context Tok *oauth2.Token ExpectedIDTokenNonce nonce.Nonce RequireIDToken bool @@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL { return u.URL } -func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { +func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } u.performRefreshCallCount++ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ - Ctx: ctx, DN: storedRefreshAttributes.DN, ExpectedUsername: storedRefreshAttributes.Username, ExpectedSubject: storedRefreshAttributes.Subject, @@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r } u.performRefreshCallCount++ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ - Ctx: ctx, RefreshToken: refreshToken, }) return u.PerformRefreshFunc(ctx, refreshToken) @@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx } u.validateTokenAndMergeWithUserInfoCallCount++ u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{ - Ctx: ctx, Tok: tok, ExpectedIDTokenNonce: expectedIDTokenNonce, RequireIDToken: requireIDToken,