When upstream OIDC refresh fails inconclusively, retry a few times

This commit is contained in:
Ryan Richard 2022-01-19 12:23:11 -08:00
parent 78bdb1928a
commit 3301a62053
3 changed files with 155 additions and 31 deletions

View File

@ -8,10 +8,13 @@ import (
"context"
"errors"
"net/http"
"time"
"github.com/ory/fosite"
"github.com/ory/x/errorsx"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
@ -39,6 +42,18 @@ var (
// NewHandler returns the http.Handler produced by newHandler, supplying the
// default backoff factor used when retrying a failed upstream refresh.
func NewHandler(
	idpLister oidc.UpstreamIdentityProvidersLister,
	oauthHelper fosite.OAuth2Provider,
) http.Handler {
	// After each failed upstream refresh attempt, the next sleep is the previous
	// sleep multiplied by this factor. It is passed to newHandler as a parameter
	// (instead of being hard-coded there) purely so that unit tests can choose a
	// smaller factor and avoid running slowly.
	const defaultRetryOnErrorFactor = 4.0

	return newHandler(idpLister, oauthHelper, defaultRetryOnErrorFactor)
}
func newHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
upstreamRefreshRetryOnErrorFactor float64,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
session := psession.NewPinnipedSession()
@ -55,7 +70,7 @@ func NewHandler(
// The session, requested scopes, and requested audience from the original authorize request was retrieved
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
// have already been granted on the accessRequest.
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
if err != nil {
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err)
@ -76,7 +91,12 @@ func NewHandler(
})
}
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
func upstreamRefresh(
ctx context.Context,
accessRequest fosite.AccessRequester,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, providerCache)
return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, providerCache, session)
case psession.ProviderTypeActiveDirectory:
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
}
}
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
func upstreamOIDCRefresh(
ctx context.Context,
session *psession.PinnipedSession,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
s := session.Custom
if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
var tokens *oauth2.Token
if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.",
@ -187,6 +212,44 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
return nil
}
// performUpstreamOIDCRefreshWithRetriesOnError calls p.PerformRefresh with the
// upstream refresh token stored in s.OIDC, retrying with exponential backoff
// when an attempt fails inconclusively.
//
// An attempt is retried unless one of the following holds:
//   - ctx has been cancelled or has timed out;
//   - the error is an *oauth2.RetrieveError carrying an HTTP response whose
//     status code is below 500 (the upstream gave a conclusive answer).
//
// Every other error, including a 5xx response, a connection error, or a
// RetrieveError without a response to inspect, is treated as inconclusive.
//
// retryOnErrorFactor is the multiplier applied to the sleep duration between
// consecutive attempts. The caller must ensure that s.OIDC is non-nil.
//
// It returns the tokens of the first successful attempt, or a nil token and the
// error of the last attempt that was made.
func performUpstreamOIDCRefreshWithRetriesOnError(
	ctx context.Context,
	p provider.UpstreamOIDCIdentityProviderI,
	s *psession.CustomSessionData,
	retryOnErrorFactor float64,
) (*oauth2.Token, error) {
	var tokens *oauth2.Token

	// For the default retryOnErrorFactor of 4.0 this backoff means...
	// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
	// Give up after a total of 6 tries over ~17s if they all resulted in errors.
	backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}

	isRetryableError := func(err error) bool {
		plog.DebugErr("upstream refresh request failed in retry loop", err,
			"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)

		if ctx.Err() != nil {
			return false // Stop retrying if the context was closed (cancelled or timed out).
		}

		var retrieveError *oauth2.RetrieveError
		// Guard both the matched pointer and its Response before reading the status code:
		// a RetrieveError that carries no response would otherwise cause a nil pointer
		// dereference here. Without a status code the failure is inconclusive, so it
		// falls through to the default of retrying.
		if errors.As(err, &retrieveError) && retrieveError != nil && retrieveError.Response != nil {
			return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
		}

		return true // Retry any other errors, e.g. connection errors.
	}

	performRefreshOnce := func() error {
		var err error
		tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
		return err
	}

	// If all retries failed, then the returned error is the error of the final failed retry.
	if err := retry.OnError(backoff, isRetryableError, performRefreshOnce); err != nil {
		// Never hand back whatever a failed attempt may have left in tokens.
		return nil, err
	}
	return tokens, nil
}
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
s := session.Custom

View File

@ -218,6 +218,7 @@ var (
)
// expectedUpstreamRefresh describes the PerformRefresh() call(s) that a test case
// expects to have been made against an upstream identity provider.
type expectedUpstreamRefresh struct {
// The assertion in the test adds one to this value (for the original attempt)
// to get the total number of expected PerformRefresh() calls.
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
// Name of the upstream on which the call(s) are expected to have been made.
performedByUpstreamName string
// Arguments expected for the call(s). The test overwrites args.Ctx with the
// actual request context before asserting.
args *oidctestutil.PerformRefreshArgs
}
@ -1733,7 +1734,7 @@ func TestRefreshGrant(t *testing.T) {
},
},
{
name: "when the upstream refresh fails during the refresh request",
name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
authcodeExchange: authcodeExchangeInputs{
@ -1743,7 +1744,65 @@ func TestRefreshGrant(t *testing.T) {
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
@ -2670,7 +2729,8 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToPerformRefresh(t,
test.idps.RequireExactlyNCallsToPerformRefresh(t,
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamRefreshCall.args,
)
@ -2796,7 +2856,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
test.modifyStorage(t, oauthStore, authCode)
}
subject = NewHandler(idps, oauthHelper)
// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
upstreamRefreshRetryOnErrorFactor := 1.0
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
expectedNumberOfIDSessionsStored := 0

View File

@ -472,46 +472,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
t *testing.T,
expectedNumberOfCalls int,
expectedPerformedByUpstreamName string,
expectedArgs *PerformRefreshArgs,
) {
t.Helper()
var actualArgs *PerformRefreshArgs
var actualNameOfUpstreamWhichMadeCall string
actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
actualCallCountAcrossAllUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
}
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
actualArgs = upstreamLDAP.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
}
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
actualArgs = upstreamAD.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
}
require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams")
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream at least once")
}
for _, actualArgs := range actualArgsOfAllCalls {
require.Equal(t, expectedArgs, actualArgs,
"PerformRefresh() was called with the wrong arguments at least once")
}
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream",
)
require.Equal(t, expectedArgs, actualArgs)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {