When upstream OIDC refresh fails inconclusively, retry a few times
This commit is contained in:
parent
78bdb1928a
commit
3301a62053
@ -8,10 +8,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/x/errorsx"
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/client-go/util/retry"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
@ -39,6 +42,18 @@ var (
|
||||
func NewHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
) http.Handler {
|
||||
// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
|
||||
// This only exists as a parameter so that unit tests can override it to avoid running slowly.
|
||||
upstreamRefreshRetryOnErrorFactor := 4.0
|
||||
|
||||
return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||
}
|
||||
|
||||
func newHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
upstreamRefreshRetryOnErrorFactor float64,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
session := psession.NewPinnipedSession()
|
||||
@ -55,7 +70,7 @@ func NewHandler(
|
||||
// The session, requested scopes, and requested audience from the original authorize request was retrieved
|
||||
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
|
||||
// have already been granted on the accessRequest.
|
||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
|
||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
|
||||
if err != nil {
|
||||
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
|
||||
oauthHelper.WriteAccessError(w, accessRequest, err)
|
||||
@ -76,7 +91,12 @@ func NewHandler(
|
||||
})
|
||||
}
|
||||
|
||||
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
func upstreamRefresh(
|
||||
ctx context.Context,
|
||||
accessRequest fosite.AccessRequester,
|
||||
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||
retryOnErrorFactor float64,
|
||||
) error {
|
||||
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
||||
|
||||
customSessionData := session.Custom
|
||||
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
||||
|
||||
switch customSessionData.ProviderType {
|
||||
case psession.ProviderTypeOIDC:
|
||||
return upstreamOIDCRefresh(ctx, session, providerCache)
|
||||
return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
|
||||
case psession.ProviderTypeLDAP:
|
||||
return upstreamLDAPRefresh(ctx, providerCache, session)
|
||||
case psession.ProviderTypeActiveDirectory:
|
||||
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
||||
}
|
||||
}
|
||||
|
||||
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
func upstreamOIDCRefresh(
|
||||
ctx context.Context,
|
||||
session *psession.PinnipedSession,
|
||||
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||
retryOnErrorFactor float64,
|
||||
) error {
|
||||
s := session.Custom
|
||||
if s.OIDC == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
|
||||
var tokens *oauth2.Token
|
||||
if refreshTokenStored {
|
||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
"Upstream refresh failed.",
|
||||
@ -187,6 +212,44 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
return nil
|
||||
}
|
||||
|
||||
func performUpstreamOIDCRefreshWithRetriesOnError(
|
||||
ctx context.Context,
|
||||
p provider.UpstreamOIDCIdentityProviderI,
|
||||
s *psession.CustomSessionData,
|
||||
retryOnErrorFactor float64,
|
||||
) (*oauth2.Token, error) {
|
||||
var tokens *oauth2.Token
|
||||
|
||||
// For the default retryOnErrorFactor of 4.0 this backoff means...
|
||||
// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
|
||||
// Give up after a total of 6 tries over ~17s if they all resulted in errors.
|
||||
backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
|
||||
|
||||
isRetryableError := func(err error) bool {
|
||||
plog.DebugErr("upstream refresh request failed in retry loop", err,
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
if ctx.Err() != nil {
|
||||
return false // Stop retrying if the context was closed (cancelled or timed out).
|
||||
}
|
||||
retrieveError := &oauth2.RetrieveError{}
|
||||
if errors.As(err, &retrieveError) {
|
||||
return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
|
||||
}
|
||||
return true // Retry any other errors, e.g. connection errors.
|
||||
}
|
||||
|
||||
performRefreshOnce := func() error {
|
||||
var err error
|
||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
return err
|
||||
}
|
||||
|
||||
err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
|
||||
|
||||
// If all retries failed, then err will hold the error of the final failed retry.
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
|
||||
s := session.Custom
|
||||
|
||||
|
@ -218,6 +218,7 @@ var (
|
||||
)
|
||||
|
||||
type expectedUpstreamRefresh struct {
|
||||
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
|
||||
performedByUpstreamName string
|
||||
args *oidctestutil.PerformRefreshArgs
|
||||
}
|
||||
@ -1733,7 +1734,7 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails during the refresh request",
|
||||
name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
@ -1743,7 +1744,65 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
|
||||
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
"error": "error",
|
||||
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||
}
|
||||
`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
"error": "error",
|
||||
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||
}
|
||||
`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
@ -2670,7 +2729,8 @@ func TestRefreshGrant(t *testing.T) {
|
||||
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
|
||||
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
|
||||
test.idps.RequireExactlyOneCallToPerformRefresh(t,
|
||||
test.idps.RequireExactlyNCallsToPerformRefresh(t,
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.args,
|
||||
)
|
||||
@ -2796,7 +2856,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
|
||||
test.modifyStorage(t, oauthStore, authCode)
|
||||
}
|
||||
|
||||
subject = NewHandler(idps, oauthHelper)
|
||||
// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
|
||||
upstreamRefreshRetryOnErrorFactor := 1.0
|
||||
|
||||
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||
|
||||
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
|
||||
expectedNumberOfIDSessionsStored := 0
|
||||
|
@ -472,46 +472,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
|
||||
)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
|
||||
t *testing.T,
|
||||
expectedNumberOfCalls int,
|
||||
expectedPerformedByUpstreamName string,
|
||||
expectedArgs *PerformRefreshArgs,
|
||||
) {
|
||||
t.Helper()
|
||||
var actualArgs *PerformRefreshArgs
|
||||
var actualNameOfUpstreamWhichMadeCall string
|
||||
actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
|
||||
actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
|
||||
actualCallCountAcrossAllUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
||||
actualArgs = upstreamOIDC.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
|
||||
}
|
||||
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
|
||||
actualArgs = upstreamLDAP.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
|
||||
}
|
||||
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
|
||||
actualArgs = upstreamAD.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
|
||||
}
|
||||
require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
|
||||
"should have been exactly one call to PerformRefresh() by all upstreams")
|
||||
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"PerformRefresh() was called on the wrong upstream at least once")
|
||||
}
|
||||
for _, actualArgs := range actualArgsOfAllCalls {
|
||||
require.Equal(t, expectedArgs, actualArgs,
|
||||
"PerformRefresh() was called with the wrong arguments at least once")
|
||||
}
|
||||
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
|
||||
"should have been exactly one call to PerformRefresh() by all upstreams",
|
||||
)
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"PerformRefresh() was called on the wrong upstream",
|
||||
)
|
||||
require.Equal(t, expectedArgs, actualArgs)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user