When upstream OIDC refresh fails inconclusively, retry a few times

This commit is contained in:
Ryan Richard 2022-01-19 12:23:11 -08:00
parent 78bdb1928a
commit 3301a62053
3 changed files with 155 additions and 31 deletions

View File

@ -8,10 +8,13 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"time"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/x/errorsx" "github.com/ory/x/errorsx"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
@ -39,6 +42,18 @@ var (
func NewHandler( func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister, idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
) http.Handler {
// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
// This only exists as a parameter so that unit tests can override it to avoid running slowly.
upstreamRefreshRetryOnErrorFactor := 4.0
return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
}
func newHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
upstreamRefreshRetryOnErrorFactor float64,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
session := psession.NewPinnipedSession() session := psession.NewPinnipedSession()
@ -55,7 +70,7 @@ func NewHandler(
// The session, requested scopes, and requested audience from the original authorize request was retrieved // The session, requested scopes, and requested audience from the original authorize request was retrieved
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may // from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
// have already been granted on the accessRequest. // have already been granted on the accessRequest.
err = upstreamRefresh(r.Context(), accessRequest, idpLister) err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
if err != nil { if err != nil {
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...) plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err) oauthHelper.WriteAccessError(w, accessRequest, err)
@ -76,7 +91,12 @@ func NewHandler(
}) })
} }
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(
ctx context.Context,
accessRequest fosite.AccessRequester,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom customSessionData := session.Custom
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType { switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC: case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, providerCache) return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
case psession.ProviderTypeLDAP: case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, providerCache, session) return upstreamLDAPRefresh(ctx, providerCache, session)
case psession.ProviderTypeActiveDirectory: case psession.ProviderTypeActiveDirectory:
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
} }
} }
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamOIDCRefresh(
ctx context.Context,
session *psession.PinnipedSession,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
s := session.Custom s := session.Custom
if s.OIDC == nil { if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
var tokens *oauth2.Token var tokens *oauth2.Token
if refreshTokenStored { if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint( return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.", "Upstream refresh failed.",
@ -187,6 +212,44 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
return nil return nil
} }
func performUpstreamOIDCRefreshWithRetriesOnError(
ctx context.Context,
p provider.UpstreamOIDCIdentityProviderI,
s *psession.CustomSessionData,
retryOnErrorFactor float64,
) (*oauth2.Token, error) {
var tokens *oauth2.Token
// For the default retryOnErrorFactor of 4.0 this backoff means...
// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
// Give up after a total of 6 tries over ~17s if they all resulted in errors.
backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
isRetryableError := func(err error) bool {
plog.DebugErr("upstream refresh request failed in retry loop", err,
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
if ctx.Err() != nil {
return false // Stop retrying if the context was closed (cancelled or timed out).
}
retrieveError := &oauth2.RetrieveError{}
if errors.As(err, &retrieveError) {
return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
}
return true // Retry any other errors, e.g. connection errors.
}
performRefreshOnce := func() error {
var err error
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
return err
}
err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
// If all retries failed, then err will hold the error of the final failed retry.
return tokens, err
}
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error { func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
s := session.Custom s := session.Custom

View File

@ -218,6 +218,7 @@ var (
) )
type expectedUpstreamRefresh struct { type expectedUpstreamRefresh struct {
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
performedByUpstreamName string performedByUpstreamName string
args *oidctestutil.PerformRefreshArgs args *oidctestutil.PerformRefreshArgs
} }
@ -1733,7 +1734,7 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
{ {
name: "when the upstream refresh fails during the refresh request", name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
@ -1743,7 +1744,65 @@ func TestRefreshGrant(t *testing.T) {
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -2670,7 +2729,8 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
if test.refreshRequest.want.wantUpstreamRefreshCall != nil { if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToPerformRefresh(t, test.idps.RequireExactlyNCallsToPerformRefresh(t,
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamRefreshCall.args, test.refreshRequest.want.wantUpstreamRefreshCall.args,
) )
@ -2796,7 +2856,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
test.modifyStorage(t, oauthStore, authCode) test.modifyStorage(t, oauthStore, authCode)
} }
subject = NewHandler(idps, oauthHelper) // Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
upstreamRefreshRetryOnErrorFactor := 1.0
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid") authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
expectedNumberOfIDSessionsStored := 0 expectedNumberOfIDSessionsStored := 0

View File

@ -472,46 +472,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
) )
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
t *testing.T, t *testing.T,
expectedNumberOfCalls int,
expectedPerformedByUpstreamName string, expectedPerformedByUpstreamName string,
expectedArgs *PerformRefreshArgs, expectedArgs *PerformRefreshArgs,
) { ) {
t.Helper() t.Helper()
var actualArgs *PerformRefreshArgs actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
var actualNameOfUpstreamWhichMadeCall string actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
actualCallCountAcrossAllUpstreams := 0 actualCallCountAcrossAllUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
actualArgs = upstreamOIDC.performRefreshArgs[0]
}
} }
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
actualArgs = upstreamLDAP.performRefreshArgs[0]
}
} }
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders { for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
callCountOnThisUpstream := upstreamAD.performRefreshCallCount callCountOnThisUpstream := upstreamAD.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
actualArgs = upstreamAD.performRefreshArgs[0] }
} require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams")
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream at least once")
}
for _, actualArgs := range actualArgsOfAllCalls {
require.Equal(t, expectedArgs, actualArgs,
"PerformRefresh() was called with the wrong arguments at least once")
} }
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream",
)
require.Equal(t, expectedArgs, actualArgs)
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {