Compare commits

...

3 Commits

Author SHA1 Message Date
Ryan Richard
7f9867598c Merge branch 'main' into upstream-oidc-refresh-retries 2022-01-20 13:14:06 -08:00
Ryan Richard
ba83c12f93 Add a timeout to the upstream OIDC refresh calls
Unfortunately this means that we lose test coverage on passing the
request's context through to the mocked PerformRefresh function. We
used to be able to assert equality on the request's context because it
was passed through unchanged.
2022-01-20 13:11:05 -08:00
Ryan Richard
3301a62053 When upstream OIDC refresh fails inconclusively, retry a few times 2022-01-19 12:23:11 -08:00
3 changed files with 158 additions and 46 deletions

View File

@ -8,10 +8,13 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"time"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/x/errorsx" "github.com/ory/x/errorsx"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/util/retry"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
@ -39,6 +42,18 @@ var (
func NewHandler( func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister, idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
) http.Handler {
// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
// This only exists as a parameter so that unit tests can override it to avoid running slowly.
upstreamRefreshRetryOnErrorFactor := 4.0
return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
}
func newHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
upstreamRefreshRetryOnErrorFactor float64,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
session := psession.NewPinnipedSession() session := psession.NewPinnipedSession()
@ -55,7 +70,7 @@ func NewHandler(
// The session, requested scopes, and requested audience from the original authorize request was retrieved // The session, requested scopes, and requested audience from the original authorize request was retrieved
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may // from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
// have already been granted on the accessRequest. // have already been granted on the accessRequest.
err = upstreamRefresh(r.Context(), accessRequest, idpLister) err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
if err != nil { if err != nil {
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...) plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err) oauthHelper.WriteAccessError(w, accessRequest, err)
@ -76,7 +91,12 @@ func NewHandler(
}) })
} }
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(
ctx context.Context,
accessRequest fosite.AccessRequester,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom customSessionData := session.Custom
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType { switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC: case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, providerCache) return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
case psession.ProviderTypeLDAP: case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, providerCache, session) return upstreamLDAPRefresh(ctx, providerCache, session)
case psession.ProviderTypeActiveDirectory: case psession.ProviderTypeActiveDirectory:
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
} }
} }
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamOIDCRefresh(
ctx context.Context,
session *psession.PinnipedSession,
providerCache oidc.UpstreamIdentityProvidersLister,
retryOnErrorFactor float64,
) error {
s := session.Custom s := session.Custom
if s.OIDC == nil { if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
var tokens *oauth2.Token var tokens *oauth2.Token
if refreshTokenStored { if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint( return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.", "Upstream refresh failed.",
@ -189,6 +214,47 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
return nil return nil
} }
func performUpstreamOIDCRefreshWithRetriesOnError(
ctx context.Context,
p provider.UpstreamOIDCIdentityProviderI,
s *psession.CustomSessionData,
retryOnErrorFactor float64,
) (*oauth2.Token, error) {
var tokens *oauth2.Token
// For the default retryOnErrorFactor of 4.0 this backoff means...
// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
// Give up after a total of 6 tries over ~17s if they all resulted in errors.
backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
isRetryableError := func(err error) bool {
plog.DebugErr("upstream refresh request failed in retry loop", err,
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
if ctx.Err() != nil {
return false // Stop retrying if the context was closed (cancelled or timed out).
}
retrieveError := &oauth2.RetrieveError{}
if errors.As(err, &retrieveError) {
return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
}
return true // Retry any other errors, e.g. connection errors.
}
performRefreshOnce := func() error {
var err error
// Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
defer cancel()
tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
return err
}
err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
// If all retries failed, then err will hold the error of the final failed retry.
return tokens, err
}
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error { func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
s := session.Custom s := session.Custom

View File

@ -225,6 +225,7 @@ var (
) )
type expectedUpstreamRefresh struct { type expectedUpstreamRefresh struct {
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
performedByUpstreamName string performedByUpstreamName string
args *oidctestutil.PerformRefreshArgs args *oidctestutil.PerformRefreshArgs
} }
@ -956,7 +957,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
RefreshToken: oidcUpstreamInitialRefreshToken, RefreshToken: oidcUpstreamInitialRefreshToken,
}, },
} }
@ -966,7 +966,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: ldapUpstreamName, performedByUpstreamName: ldapUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: ldapUpstreamDN, DN: ldapUpstreamDN,
ExpectedSubject: goodSubject, ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername, ExpectedUsername: goodUsername,
@ -978,7 +977,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamRefresh{ return &expectedUpstreamRefresh{
performedByUpstreamName: activeDirectoryUpstreamName, performedByUpstreamName: activeDirectoryUpstreamName,
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil,
DN: activeDirectoryUpstreamDN, DN: activeDirectoryUpstreamDN,
ExpectedSubject: goodSubject, ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername, ExpectedUsername: goodUsername,
@ -990,7 +988,6 @@ func TestRefreshGrant(t *testing.T) {
return &expectedUpstreamValidateTokens{ return &expectedUpstreamValidateTokens{
performedByUpstreamName: oidcUpstreamName, performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: expectedTokens, Tok: expectedTokens,
ExpectedIDTokenNonce: "", // always expect empty string ExpectedIDTokenNonce: "", // always expect empty string
RequireUserInfo: false, RequireUserInfo: false,
@ -1155,7 +1152,6 @@ func TestRefreshGrant(t *testing.T) {
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{ wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
oidcUpstreamName, oidcUpstreamName,
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{ &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
Ctx: nil, // this will be filled in with the actual request context by the test below
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
ExpectedIDTokenNonce: "", // always expect empty string ExpectedIDTokenNonce: "", // always expect empty string
RequireIDToken: false, RequireIDToken: false,
@ -1794,7 +1790,7 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
{ {
name: "when the upstream refresh fails during the refresh request", name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder(). idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()), WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
@ -1804,7 +1800,63 @@ func TestRefreshGrant(t *testing.T) {
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
performedByUpstreamName: oidcUpstreamName,
args: &oidctestutil.PerformRefreshArgs{
RefreshToken: oidcUpstreamInitialRefreshToken,
},
},
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -2715,9 +2767,8 @@ func TestRefreshGrant(t *testing.T) {
if test.modifyRefreshTokenStorage != nil { if test.modifyRefreshTokenStorage != nil {
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken) test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
} }
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
req := httptest.NewRequest("POST", "/path/shouldn't/matter", req := httptest.NewRequest("POST", "/path/shouldn't/matter",
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext) happyRefreshRequestBody(firstRefreshToken).ReadCloser())
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if test.refreshRequest.modifyTokenRequest != nil { if test.refreshRequest.modifyTokenRequest != nil {
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string)) test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
@ -2730,8 +2781,8 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh. // Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
if test.refreshRequest.want.wantUpstreamRefreshCall != nil { if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext test.idps.RequireExactlyNCallsToPerformRefresh(t,
test.idps.RequireExactlyOneCallToPerformRefresh(t, test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamRefreshCall.args, test.refreshRequest.want.wantUpstreamRefreshCall.args,
) )
@ -2742,7 +2793,6 @@ func TestRefreshGrant(t *testing.T) {
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the // Test that we did or did not make a call to the upstream OIDC provider interface to validate the
// new ID token that was returned by the upstream refresh. // new ID token that was returned by the upstream refresh.
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil { if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
test.idps.RequireExactlyOneCallToValidateToken(t, test.idps.RequireExactlyOneCallToValidateToken(t,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args, test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
@ -2857,7 +2907,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
test.modifyStorage(t, oauthStore, authCode) test.modifyStorage(t, oauthStore, authCode)
} }
subject = NewHandler(idps, oauthHelper) // Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
upstreamRefreshRetryOnErrorFactor := 1.0
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid") authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
expectedNumberOfIDSessionsStored := 0 expectedNumberOfIDSessionsStored := 0

View File

@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
// PerformRefreshArgs is used to spy on calls to // PerformRefreshArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). // TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
type PerformRefreshArgs struct { type PerformRefreshArgs struct {
Ctx context.Context
RefreshToken string RefreshToken string
DN string DN string
ExpectedUsername string ExpectedUsername string
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc(). // TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
type ValidateTokenAndMergeWithUserInfoArgs struct { type ValidateTokenAndMergeWithUserInfoArgs struct {
Ctx context.Context
Tok *oauth2.Token Tok *oauth2.Token
ExpectedIDTokenNonce nonce.Nonce ExpectedIDTokenNonce nonce.Nonce
RequireIDToken bool RequireIDToken bool
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL return u.URL
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error { func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
if u.performRefreshArgs == nil { if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0) u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
} }
u.performRefreshCallCount++ u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
DN: storedRefreshAttributes.DN, DN: storedRefreshAttributes.DN,
ExpectedUsername: storedRefreshAttributes.Username, ExpectedUsername: storedRefreshAttributes.Username,
ExpectedSubject: storedRefreshAttributes.Subject, ExpectedSubject: storedRefreshAttributes.Subject,
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
} }
u.performRefreshCallCount++ u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
RefreshToken: refreshToken, RefreshToken: refreshToken,
}) })
return u.PerformRefreshFunc(ctx, refreshToken) return u.PerformRefreshFunc(ctx, refreshToken)
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
} }
u.validateTokenAndMergeWithUserInfoCallCount++ u.validateTokenAndMergeWithUserInfoCallCount++
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{ u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
Ctx: ctx,
Tok: tok, Tok: tok,
ExpectedIDTokenNonce: expectedIDTokenNonce, ExpectedIDTokenNonce: expectedIDTokenNonce,
RequireIDToken: requireIDToken, RequireIDToken: requireIDToken,
@ -472,46 +467,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
) )
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
t *testing.T, t *testing.T,
expectedNumberOfCalls int,
expectedPerformedByUpstreamName string, expectedPerformedByUpstreamName string,
expectedArgs *PerformRefreshArgs, expectedArgs *PerformRefreshArgs,
) { ) {
t.Helper() t.Helper()
var actualArgs *PerformRefreshArgs actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
var actualNameOfUpstreamWhichMadeCall string actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
actualCallCountAcrossAllUpstreams := 0 actualCallCountAcrossAllUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
actualArgs = upstreamOIDC.performRefreshArgs[0]
}
} }
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
actualArgs = upstreamLDAP.performRefreshArgs[0]
}
} }
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders { for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
callCountOnThisUpstream := upstreamAD.performRefreshCallCount callCountOnThisUpstream := upstreamAD.performRefreshCallCount
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
actualArgs = upstreamAD.performRefreshArgs[0] }
} require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams")
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream at least once")
}
for _, actualArgs := range actualArgsOfAllCalls {
require.Equal(t, expectedArgs, actualArgs,
"PerformRefresh() was called with the wrong arguments at least once")
} }
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
"should have been exactly one call to PerformRefresh() by all upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong upstream",
)
require.Equal(t, expectedArgs, actualArgs)
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {