Compare commits

...

3 Commits

Author SHA1 Message Date
Ryan Richard
7f9867598c Merge branch 'main' into upstream-oidc-refresh-retries 2022-01-20 13:14:06 -08:00
Ryan Richard
ba83c12f93 Add a timeout to the upstream OIDC refresh calls
Unfortunately this means that we lose test coverage on passing the
request's context through to the mocked PerformRefresh function. We
used to be able to assert equality on the request's context because it
was passed through unchanged.
2022-01-20 13:11:05 -08:00
Ryan Richard
3301a62053 When upstream OIDC refresh fails inconclusively, retry a few times 2022-01-19 12:23:11 -08:00
3 changed files with 158 additions and 46 deletions

View File

@ -8,10 +8,13 @@ import (
"context"
"errors"
"net/http"
"time"
"github.com/ory/fosite"
"github.com/ory/x/errorsx"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
@ -39,6 +42,18 @@ var (
// NewHandler returns the http.Handler built by newHandler, configured with the production
// backoff factor for retrying failed upstream refreshes.
func NewHandler(
	idpLister oidc.UpstreamIdentityProvidersLister,
	oauthHelper fosite.OAuth2Provider,
) http.Handler {
	// Multiplier applied to the previous sleep duration before each successive retry of a failed
	// upstream refresh. It is only a parameter of newHandler so that unit tests can pass a
	// different value to avoid running slowly.
	const productionRetryOnErrorFactor = 4.0

	return newHandler(idpLister, oauthHelper, productionRetryOnErrorFactor)
}
func newHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
upstreamRefreshRetryOnErrorFactor float64,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
session := psession.NewPinnipedSession()
@ -55,7 +70,7 @@ func NewHandler(
// The session, requested scopes, and requested audience from the original authorize request was retrieved
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
// have already been granted on the accessRequest.
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
if err != nil {
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err)
@ -76,7 +91,12 @@ func NewHandler(
})
}
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
func upstreamRefresh(
ctx context.Context,
accessRequest fosite.AccessRequester,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, providerCache)
return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, providerCache, session)
case psession.ProviderTypeActiveDirectory:
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
}
}
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
func upstreamOIDCRefresh(
ctx context.Context,
session *psession.PinnipedSession,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
s := session.Custom
if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
var tokens *oauth2.Token
if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.",
@ -189,6 +214,47 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
return nil
}
// performUpstreamOIDCRefreshWithRetriesOnError calls p.PerformRefresh with the upstream refresh
// token stored in s.OIDC, retrying with exponential backoff when an attempt fails inconclusively.
//
// A failed attempt is retried unless ctx has been cancelled or has timed out, or the error is an
// *oauth2.RetrieveError whose HTTP response status is below 500. Each attempt runs under its own
// 45 second timeout derived from ctx. retryOnErrorFactor is the multiplier applied to the sleep
// duration between successive retries.
//
// It returns the tokens from the first successful attempt. If no attempt succeeds, the returned
// error is the one from the last attempt that was made.
func performUpstreamOIDCRefreshWithRetriesOnError(
	ctx context.Context,
	p provider.UpstreamOIDCIdentityProviderI,
	s *psession.CustomSessionData,
	retryOnErrorFactor float64,
) (*oauth2.Token, error) {
	var tokens *oauth2.Token

	// For the default retryOnErrorFactor of 4.0 this backoff means...
	// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
	// Give up after a total of 6 tries over ~17s if they all resulted in errors.
	backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}

	isRetryableError := func(err error) bool {
		plog.DebugErr("upstream refresh request failed in retry loop", err,
			"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
		if ctx.Err() != nil {
			return false // Stop retrying if the context was closed (cancelled or timed out).
		}
		var retrieveError *oauth2.RetrieveError
		if errors.As(err, &retrieveError) {
			if retrieveError == nil || retrieveError.Response == nil {
				// No HTTP status is available to classify the failure, so treat it as
				// inconclusive (like a connection error) instead of dereferencing a nil pointer.
				return true
			}
			return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
		}
		return true // Retry any other errors, e.g. connection errors.
	}

	performRefreshOnce := func() error {
		var err error
		// Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
		timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
		defer cancel() // Runs when this single attempt returns, so each attempt releases its own timer.
		tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
		return err
	}

	err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
	// If all retries failed, then err will hold the error of the final failed retry.
	return tokens, err
}
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
s := session.Custom

View File

@ -225,6 +225,7 @@ var (
)
// expectedUpstreamRefresh describes the upstream PerformRefresh call(s) that a test case expects
// to have been made while handling the refresh request.
type expectedUpstreamRefresh struct {
// The test adds one to this value (for the original attempt) to get the total number of
// PerformRefresh calls that it requires.
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
// Name of the upstream identity provider on which the PerformRefresh calls are expected.
performedByUpstreamName string
// Arguments that the PerformRefresh calls are expected to have received.
args *oidctestutil.PerformRefreshArgs
}
@ -956,7 +957,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken,
},
}
@ -966,7 +966,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{
performedByUpstreamName: ldapUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: ldapUpstreamDN,
ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername,
@ -978,7 +977,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{
performedByUpstreamName: activeDirectoryUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: activeDirectoryUpstreamDN,
ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername,
@ -990,7 +988,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamValidateTokens{
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: expectedTokens,
ExpectedIDTokenNonce: "", // always expect empty string
RequireUserInfo: false,
@ -1155,7 +1152,6 @@ func TestRefreshGrant(t *testing.T) {
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
oidcUpstreamName,
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
ExpectedIDTokenNonce: "", // always expect empty string
RequireIDToken: false,
@ -1794,7 +1790,7 @@ func TestRefreshGrant(t *testing.T) {
},
},
{
name: "when the upstream refresh fails during the refresh request",
name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
authcodeExchange: authcodeExchangeInputs{
@ -1804,7 +1800,63 @@ func TestRefreshGrant(t *testing.T) {
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
@ -2715,9 +2767,8 @@ func TestRefreshGrant(t *testing.T) {
if test.modifyRefreshTokenStorage != nil {
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
}
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
req := httptest.NewRequest("POST", "/path/shouldn't/matter",
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext)
happyRefreshRequestBody(firstRefreshToken).ReadCloser())
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if test.refreshRequest.modifyTokenRequest != nil {
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
@ -2730,8 +2781,8 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToPerformRefresh(t,
test.idps.RequireExactlyNCallsToPerformRefresh(t,
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamRefreshCall.args,
)
@ -2742,7 +2793,6 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the
// new ID token that was returned by the upstream refresh.
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToValidateToken(t,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
@ -2857,7 +2907,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
test.modifyStorage(t, oauthStore, authCode)
}
subject = NewHandler(idps, oauthHelper)
// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
upstreamRefreshRetryOnErrorFactor := 1.0
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
expectedNumberOfIDSessionsStored := 0

View File

@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
// PerformRefreshArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
type PerformRefreshArgs struct {
Ctx context.Context
RefreshToken string
DN string
ExpectedUsername string
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
type ValidateTokenAndMergeWithUserInfoArgs struct {
Ctx context.Context
Tok *oauth2.Token
ExpectedIDTokenNonce nonce.Nonce
RequireIDToken bool
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL
}
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
}
u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
DN: storedRefreshAttributes.DN,
ExpectedUsername: storedRefreshAttributes.Username,
ExpectedSubject: storedRefreshAttributes.Subject,
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
}
u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
RefreshToken: refreshToken,
})
return u.PerformRefreshFunc(ctx, refreshToken)
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
}
u.validateTokenAndMergeWithUserInfoCallCount++
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
Ctx: ctx,
Tok: tok,
ExpectedIDTokenNonce: expectedIDTokenNonce,
RequireIDToken: requireIDToken,
@ -472,46 +467,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
t *testing.T,
expectedNumberOfCalls int,
expectedPerformedByUpstreamName string,
expectedArgs *PerformRefreshArgs,
) {
t.Helper()
var actualArgs *PerformRefreshArgs
var actualNameOfUpstreamWhichMadeCall string
actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
actualCallCountAcrossAllUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
}
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
actualArgs = upstreamLDAP.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
}
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
actualArgs = upstreamAD.performRefreshArgs[0]
}
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
}
require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams")
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream at least once")
}
for _, actualArgs := range actualArgsOfAllCalls {
require.Equal(t, expectedArgs, actualArgs,
"PerformRefresh() was called with the wrong arguments at least once")
}
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream",
)
require.Equal(t, expectedArgs, actualArgs)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {