From 3301a62053a33da000d0af7996276e88b993521c Mon Sep 17 00:00:00 2001
From: Ryan Richard
Date: Wed, 19 Jan 2022 12:23:11 -0800
Subject: [PATCH] When upstream OIDC refresh fails inconclusively, retry a few times

---
 internal/oidc/token/token_handler.go      | 73 +++++++++++++++++--
 internal/oidc/token/token_handler_test.go | 71 +++++++++++++++++-
 .../testutil/oidctestutil/oidctestutil.go  | 42 +++++------
 3 files changed, 155 insertions(+), 31 deletions(-)

diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go
index 93e1ac8f..40a8061d 100644
--- a/internal/oidc/token/token_handler.go
+++ b/internal/oidc/token/token_handler.go
@@ -8,10 +8,13 @@ import (
 	"context"
 	"errors"
 	"net/http"
+	"time"
 
 	"github.com/ory/fosite"
 	"github.com/ory/x/errorsx"
 	"golang.org/x/oauth2"
+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/client-go/util/retry"
 
 	"go.pinniped.dev/internal/httputil/httperr"
 	"go.pinniped.dev/internal/oidc"
@@ -39,6 +42,18 @@ var (
 func NewHandler(
 	idpLister oidc.UpstreamIdentityProvidersLister,
 	oauthHelper fosite.OAuth2Provider,
+) http.Handler {
+	// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
+	// This only exists as a parameter so that unit tests can override it to avoid running slowly.
+	upstreamRefreshRetryOnErrorFactor := 4.0
+
+	return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
+}
+
+func newHandler(
+	idpLister oidc.UpstreamIdentityProvidersLister,
+	oauthHelper fosite.OAuth2Provider,
+	upstreamRefreshRetryOnErrorFactor float64,
 ) http.Handler {
 	return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
 		session := psession.NewPinnipedSession()
@@ -55,7 +70,7 @@ func NewHandler(
 		// The session, requested scopes, and requested audience from the original authorize request was retrieved
 		// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
 		// have already been granted on the accessRequest.
-		err = upstreamRefresh(r.Context(), accessRequest, idpLister)
+		err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
 		if err != nil {
 			plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
 			oauthHelper.WriteAccessError(w, accessRequest, err)
@@ -76,7 +91,12 @@ func NewHandler(
 	})
 }
 
-func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
+func upstreamRefresh(
+	ctx context.Context,
+	accessRequest fosite.AccessRequester,
+	providerCache oidc.UpstreamIdentityProvidersLister,
+	retryOnErrorFactor float64,
+) error {
 	session := accessRequest.GetSession().(*psession.PinnipedSession)
 
 	customSessionData := session.Custom
@@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
 
 	switch customSessionData.ProviderType {
 	case psession.ProviderTypeOIDC:
-		return upstreamOIDCRefresh(ctx, session, providerCache)
+		return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
 	case psession.ProviderTypeLDAP:
 		return upstreamLDAPRefresh(ctx, providerCache, session)
 	case psession.ProviderTypeActiveDirectory:
@@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
 	}
 }
 
-func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
+func upstreamOIDCRefresh(
+	ctx context.Context,
+	session *psession.PinnipedSession,
+	providerCache oidc.UpstreamIdentityProvidersLister,
+	retryOnErrorFactor float64,
+) error {
 	s := session.Custom
 	if s.OIDC == nil {
 		return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
 
 	var tokens *oauth2.Token
 	if refreshTokenStored {
-		tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
+		tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
 		if err != nil {
 			return errorsx.WithStack(errUpstreamRefreshError.WithHint(
 				"Upstream refresh failed.",
@@ -187,6 +212,44 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
 	return nil
 }
 
+func performUpstreamOIDCRefreshWithRetriesOnError(
+	ctx context.Context,
+	p provider.UpstreamOIDCIdentityProviderI,
+	s *psession.CustomSessionData,
+	retryOnErrorFactor float64,
+) (*oauth2.Token, error) {
+	var tokens *oauth2.Token
+
+	// For the default retryOnErrorFactor of 4.0 this backoff means...
+	// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
+	// Give up after a total of 6 tries over ~17s if they all resulted in errors.
+	backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
+
+	isRetryableError := func(err error) bool {
+		plog.DebugErr("upstream refresh request failed in retry loop", err,
+			"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
+		if ctx.Err() != nil {
+			return false // Stop retrying if the context was closed (cancelled or timed out).
+		}
+		retrieveError := &oauth2.RetrieveError{}
+		if errors.As(err, &retrieveError) {
+			return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
+		}
+		return true // Retry any other errors, e.g. connection errors.
+	}
+
+	performRefreshOnce := func() error {
+		var err error
+		tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
+		return err
+	}
+
+	err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
+
+	// If all retries failed, then err will hold the error of the final failed retry.
+	return tokens, err
+}
+
 func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
 	s := session.Custom
 
diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go
index d9b4cb68..d771bcab 100644
--- a/internal/oidc/token/token_handler_test.go
+++ b/internal/oidc/token/token_handler_test.go
@@ -218,6 +218,7 @@ var (
 )
 
 type expectedUpstreamRefresh struct {
+	numberOfRetryAttempts   int // number of expected retries, not including the original refresh attempt
 	performedByUpstreamName string
 	args                    *oidctestutil.PerformRefreshArgs
 }
@@ -1733,7 +1734,7 @@ func TestRefreshGrant(t *testing.T) {
 			},
 		},
 		{
-			name: "when the upstream refresh fails during the refresh request",
+			name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
 			idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
 				WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
 			authcodeExchange: authcodeExchangeInputs{
@@ -1743,7 +1744,65 @@ func TestRefreshGrant(t *testing.T) {
 			},
 			refreshRequest: refreshRequestInputs{
 				want: tokenEndpointResponseExpectedValues{
-					wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
+					wantUpstreamRefreshCall: &expectedUpstreamRefresh{
+						numberOfRetryAttempts:   5, // every attempt returns a generic error, so it should reach the maximum number of retries
+						performedByUpstreamName: oidcUpstreamName,
+						args: &oidctestutil.PerformRefreshArgs{
+							Ctx:          nil, // this will be filled in with the actual request context by the test below
+							RefreshToken: oidcUpstreamInitialRefreshToken,
+						},
+					},
+					wantStatus: http.StatusUnauthorized,
+					wantErrorResponseBody: here.Doc(`
+						{
+							"error":             "error",
+							"error_description": "Error during upstream refresh. Upstream refresh failed."
+						}
+					`),
+				},
+			},
+		},
+		{
+			name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
+			idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
+				WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
+			authcodeExchange: authcodeExchangeInputs{
+				customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
+				modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
+				want:              happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
+			},
+			refreshRequest: refreshRequestInputs{
+				want: tokenEndpointResponseExpectedValues{
+					wantUpstreamRefreshCall: &expectedUpstreamRefresh{
+						numberOfRetryAttempts:   5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
+						performedByUpstreamName: oidcUpstreamName,
+						args: &oidctestutil.PerformRefreshArgs{
+							Ctx:          nil, // this will be filled in with the actual request context by the test below
+							RefreshToken: oidcUpstreamInitialRefreshToken,
+						},
+					},
+					wantStatus: http.StatusUnauthorized,
+					wantErrorResponseBody: here.Doc(`
+						{
+							"error":             "error",
+							"error_description": "Error during upstream refresh. Upstream refresh failed."
+						}
+					`),
+				},
+			},
+		},
+		{
+			name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
+			idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
+				WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
+			authcodeExchange: authcodeExchangeInputs{
+				customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
+				modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
+				want:              happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
+			},
+			refreshRequest: refreshRequestInputs{
+				want: tokenEndpointResponseExpectedValues{
+					wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
 					wantStatus:              http.StatusUnauthorized,
 					wantErrorResponseBody: here.Doc(`
 						{
@@ -2670,7 +2729,8 @@ func TestRefreshGrant(t *testing.T) {
 			// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
 			if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
 				test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
-				test.idps.RequireExactlyOneCallToPerformRefresh(t,
+				test.idps.RequireExactlyNCallsToPerformRefresh(t,
+					test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
 					test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
 					test.refreshRequest.want.wantUpstreamRefreshCall.args,
 				)
@@ -2796,7 +2856,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
 		test.modifyStorage(t, oauthStore, authCode)
 	}
 
-	subject = NewHandler(idps, oauthHelper)
+	// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
+	upstreamRefreshRetryOnErrorFactor := 1.0
+
+	subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
 
 	authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
 	expectedNumberOfIDSessionsStored := 0
diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go
index 92cc582c..3f1a5752 100644
--- a/internal/testutil/oidctestutil/oidctestutil.go
+++ b/internal/testutil/oidctestutil/oidctestutil.go
@@ -472,46 +472,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
 	)
 }
 
-func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
+func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
 	t *testing.T,
+	expectedNumberOfCalls int,
 	expectedPerformedByUpstreamName string,
 	expectedArgs *PerformRefreshArgs,
 ) {
 	t.Helper()
-	var actualArgs *PerformRefreshArgs
-	var actualNameOfUpstreamWhichMadeCall string
+	actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
+	actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
 	actualCallCountAcrossAllUpstreams := 0
 	for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
 		callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
 		actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
-		if callCountOnThisUpstream == 1 {
-			actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
-			actualArgs = upstreamOIDC.performRefreshArgs[0]
-		}
+		actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
+		actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
 	}
 	for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
 		callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
 		actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
-		if callCountOnThisUpstream == 1 {
-			actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
-			actualArgs = upstreamLDAP.performRefreshArgs[0]
-		}
+		actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
+		actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
 	}
 	for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
 		callCountOnThisUpstream := upstreamAD.performRefreshCallCount
 		actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
-		if callCountOnThisUpstream == 1 {
-			actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
-			actualArgs = upstreamAD.performRefreshArgs[0]
-		}
+		actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
+		actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
+	}
+	require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
+		"should have been exactly the expected number of calls to PerformRefresh() across all upstreams")
+	for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
+		require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
+			"PerformRefresh() was called on the wrong upstream at least once")
+	}
+	for _, actualArgs := range actualArgsOfAllCalls {
+		require.Equal(t, expectedArgs, actualArgs,
+			"PerformRefresh() was called with the wrong arguments at least once")
 	}
-	require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
-		"should have been exactly one call to PerformRefresh() by all upstreams",
-	)
-	require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
-		"PerformRefresh() was called on the wrong upstream",
-	)
-	require.Equal(t, expectedArgs, actualArgs)
 }
 
 func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {