Add a timeout to the upstream OIDC refresh calls

Unfortunately this means that we lose test coverage on passing the
request's context through to the mocked PerformRefresh function. We
used to be able to assert equality on the request's context because it
was passed through unchanged.
This commit is contained in:
Ryan Richard 2022-01-20 13:11:05 -08:00
parent 3301a62053
commit ba83c12f93
3 changed files with 6 additions and 18 deletions

View File

@ -240,7 +240,10 @@ func performUpstreamOIDCRefreshWithRetriesOnError(
performRefreshOnce := func() error { performRefreshOnce := func() error {
var err error var err error
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) // Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
return err return err
} }

View File

@ -950,7 +950,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken, RefreshToken: oidcUpstreamInitialRefreshToken,
}, },
} }
@ -960,7 +959,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: ldapUpstreamName, performedByUpstreamName: ldapUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: ldapUpstreamDN, DN: ldapUpstreamDN,
ExpectedSubject: goodSubject, ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername, ExpectedUsername: goodUsername,
@ -972,7 +970,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: activeDirectoryUpstreamName, performedByUpstreamName: activeDirectoryUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: activeDirectoryUpstreamDN, DN: activeDirectoryUpstreamDN,
ExpectedSubject: goodSubject, ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername, ExpectedUsername: goodUsername,
@ -984,7 +981,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamValidateTokens{ return &expectedUpstreamValidateTokens{
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: expectedTokens, Tok: expectedTokens,
ExpectedIDTokenNonce: "", // always expect empty string ExpectedIDTokenNonce: "", // always expect empty string
RequireUserInfo: false, RequireUserInfo: false,
@ -1149,7 +1145,6 @@ func TestRefreshGrant(t *testing.T) {
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{ wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
oidcUpstreamName, oidcUpstreamName,
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
ExpectedIDTokenNonce: "", // always expect empty string ExpectedIDTokenNonce: "", // always expect empty string
RequireIDToken: false, RequireIDToken: false,
@ -1748,7 +1743,6 @@ func TestRefreshGrant(t *testing.T) {
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken, RefreshToken: oidcUpstreamInitialRefreshToken,
}, },
}, },
@ -1777,7 +1771,6 @@ func TestRefreshGrant(t *testing.T) {
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken, RefreshToken: oidcUpstreamInitialRefreshToken,
}, },
}, },
@ -2713,9 +2706,8 @@ func TestRefreshGrant(t *testing.T) {
if test.modifyRefreshTokenStorage != nil { if test.modifyRefreshTokenStorage != nil {
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken) test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
} }
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
req := httptest.NewRequest("POST", "/path/shouldn't/matter", req := httptest.NewRequest("POST", "/path/shouldn't/matter",
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext) happyRefreshRequestBody(firstRefreshToken).ReadCloser())
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if test.refreshRequest.modifyTokenRequest != nil { if test.refreshRequest.modifyTokenRequest != nil {
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string)) test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
@ -2728,7 +2720,6 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
if test.refreshRequest.want.wantUpstreamRefreshCall != nil { if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
test.idps.RequireExactlyNCallsToPerformRefresh(t, test.idps.RequireExactlyNCallsToPerformRefresh(t,
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
@ -2741,7 +2732,6 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the // Test that we did or did not make a call to the upstream OIDC provider interface to validate the
// new ID token that was returned by the upstream refresh. // new ID token that was returned by the upstream refresh.
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil { if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToValidateToken(t, test.idps.RequireExactlyOneCallToValidateToken(t,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,

View File

@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
// PerformRefreshArgs is used to spy on calls to // PerformRefreshArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). // TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
type PerformRefreshArgs struct { type PerformRefreshArgs struct {
Ctx context.Context
RefreshToken string RefreshToken string
DN string DN string
ExpectedUsername string ExpectedUsername string
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc(). // TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
type ValidateTokenAndMergeWithUserInfoArgs struct { type ValidateTokenAndMergeWithUserInfoArgs struct {
Ctx context.Context
Tok *oauth2.Token Tok *oauth2.Token
ExpectedIDTokenNonce nonce.Nonce ExpectedIDTokenNonce nonce.Nonce
RequireIDToken bool RequireIDToken bool
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL return u.URL
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
if u.performRefreshArgs == nil { if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0) u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
} }
u.performRefreshCallCount++ u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
DN: storedRefreshAttributes.DN, DN: storedRefreshAttributes.DN,
ExpectedUsername: storedRefreshAttributes.Username, ExpectedUsername: storedRefreshAttributes.Username,
ExpectedSubject: storedRefreshAttributes.Subject, ExpectedSubject: storedRefreshAttributes.Subject,
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
} }
u.performRefreshCallCount++ u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}) })
return u.PerformRefreshFunc(ctx, refreshToken) return u.PerformRefreshFunc(ctx, refreshToken)
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
} }
u.validateTokenAndMergeWithUserInfoCallCount++ u.validateTokenAndMergeWithUserInfoCallCount++
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{ u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
Ctx: ctx,
Tok: tok, Tok: tok,
ExpectedIDTokenNonce: expectedIDTokenNonce, ExpectedIDTokenNonce: expectedIDTokenNonce,
RequireIDToken: requireIDToken, RequireIDToken: requireIDToken,