Compare commits
3 Commits
main
...
upstream-o
Author | SHA1 | Date | |
---|---|---|---|
|
7f9867598c | ||
|
ba83c12f93 | ||
|
3301a62053 |
@ -8,10 +8,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/x/errorsx"
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/util/wait"
|
||||
"k8s.io/client-go/util/retry"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
@ -39,6 +42,18 @@ var (
|
||||
func NewHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
) http.Handler {
|
||||
// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
|
||||
// This only exists as a parameter so that unit tests can override it to avoid running slowly.
|
||||
upstreamRefreshRetryOnErrorFactor := 4.0
|
||||
|
||||
return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||
}
|
||||
|
||||
func newHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
upstreamRefreshRetryOnErrorFactor float64,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
session := psession.NewPinnipedSession()
|
||||
@ -55,7 +70,7 @@ func NewHandler(
|
||||
// The session, requested scopes, and requested audience from the original authorize request was retrieved
|
||||
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
|
||||
// have already been granted on the accessRequest.
|
||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
|
||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
|
||||
if err != nil {
|
||||
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
|
||||
oauthHelper.WriteAccessError(w, accessRequest, err)
|
||||
@ -76,7 +91,12 @@ func NewHandler(
|
||||
})
|
||||
}
|
||||
|
||||
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
func upstreamRefresh(
|
||||
ctx context.Context,
|
||||
accessRequest fosite.AccessRequester,
|
||||
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||
retryOnErrorFactor float64,
|
||||
) error {
|
||||
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
||||
|
||||
customSessionData := session.Custom
|
||||
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
||||
|
||||
switch customSessionData.ProviderType {
|
||||
case psession.ProviderTypeOIDC:
|
||||
return upstreamOIDCRefresh(ctx, session, providerCache)
|
||||
return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
|
||||
case psession.ProviderTypeLDAP:
|
||||
return upstreamLDAPRefresh(ctx, providerCache, session)
|
||||
case psession.ProviderTypeActiveDirectory:
|
||||
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
||||
}
|
||||
}
|
||||
|
||||
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
func upstreamOIDCRefresh(
|
||||
ctx context.Context,
|
||||
session *psession.PinnipedSession,
|
||||
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||
retryOnErrorFactor float64,
|
||||
) error {
|
||||
s := session.Custom
|
||||
if s.OIDC == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
|
||||
var tokens *oauth2.Token
|
||||
if refreshTokenStored {
|
||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
"Upstream refresh failed.",
|
||||
@ -189,6 +214,47 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
return nil
|
||||
}
|
||||
|
||||
func performUpstreamOIDCRefreshWithRetriesOnError(
|
||||
ctx context.Context,
|
||||
p provider.UpstreamOIDCIdentityProviderI,
|
||||
s *psession.CustomSessionData,
|
||||
retryOnErrorFactor float64,
|
||||
) (*oauth2.Token, error) {
|
||||
var tokens *oauth2.Token
|
||||
|
||||
// For the default retryOnErrorFactor of 4.0 this backoff means...
|
||||
// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
|
||||
// Give up after a total of 6 tries over ~17s if they all resulted in errors.
|
||||
backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
|
||||
|
||||
isRetryableError := func(err error) bool {
|
||||
plog.DebugErr("upstream refresh request failed in retry loop", err,
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
if ctx.Err() != nil {
|
||||
return false // Stop retrying if the context was closed (cancelled or timed out).
|
||||
}
|
||||
retrieveError := &oauth2.RetrieveError{}
|
||||
if errors.As(err, &retrieveError) {
|
||||
return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
|
||||
}
|
||||
return true // Retry any other errors, e.g. connection errors.
|
||||
}
|
||||
|
||||
performRefreshOnce := func() error {
|
||||
var err error
|
||||
// Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||
defer cancel()
|
||||
tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
|
||||
return err
|
||||
}
|
||||
|
||||
err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
|
||||
|
||||
// If all retries failed, then err will hold the error of the final failed retry.
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
|
||||
s := session.Custom
|
||||
|
||||
|
@ -225,6 +225,7 @@ var (
|
||||
)
|
||||
|
||||
type expectedUpstreamRefresh struct {
|
||||
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
|
||||
performedByUpstreamName string
|
||||
args *oidctestutil.PerformRefreshArgs
|
||||
}
|
||||
@ -956,7 +957,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
}
|
||||
@ -966,7 +966,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: ldapUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil,
|
||||
DN: ldapUpstreamDN,
|
||||
ExpectedSubject: goodSubject,
|
||||
ExpectedUsername: goodUsername,
|
||||
@ -978,7 +977,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamRefresh{
|
||||
performedByUpstreamName: activeDirectoryUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
Ctx: nil,
|
||||
DN: activeDirectoryUpstreamDN,
|
||||
ExpectedSubject: goodSubject,
|
||||
ExpectedUsername: goodUsername,
|
||||
@ -990,7 +988,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
return &expectedUpstreamValidateTokens{
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
Tok: expectedTokens,
|
||||
ExpectedIDTokenNonce: "", // always expect empty string
|
||||
RequireUserInfo: false,
|
||||
@ -1155,7 +1152,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
|
||||
oidcUpstreamName,
|
||||
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
||||
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
|
||||
ExpectedIDTokenNonce: "", // always expect empty string
|
||||
RequireIDToken: false,
|
||||
@ -1794,7 +1790,7 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails during the refresh request",
|
||||
name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
@ -1804,7 +1800,63 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
|
||||
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
"error": "error",
|
||||
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||
}
|
||||
`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
|
||||
performedByUpstreamName: oidcUpstreamName,
|
||||
args: &oidctestutil.PerformRefreshArgs{
|
||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||
},
|
||||
},
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
"error": "error",
|
||||
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||
}
|
||||
`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
@ -2715,9 +2767,8 @@ func TestRefreshGrant(t *testing.T) {
|
||||
if test.modifyRefreshTokenStorage != nil {
|
||||
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
|
||||
}
|
||||
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
|
||||
req := httptest.NewRequest("POST", "/path/shouldn't/matter",
|
||||
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext)
|
||||
happyRefreshRequestBody(firstRefreshToken).ReadCloser())
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
if test.refreshRequest.modifyTokenRequest != nil {
|
||||
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
|
||||
@ -2730,8 +2781,8 @@ func TestRefreshGrant(t *testing.T) {
|
||||
|
||||
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
|
||||
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
|
||||
test.idps.RequireExactlyOneCallToPerformRefresh(t,
|
||||
test.idps.RequireExactlyNCallsToPerformRefresh(t,
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
|
||||
test.refreshRequest.want.wantUpstreamRefreshCall.args,
|
||||
)
|
||||
@ -2742,7 +2793,6 @@ func TestRefreshGrant(t *testing.T) {
|
||||
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the
|
||||
// new ID token that was returned by the upstream refresh.
|
||||
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
|
||||
test.idps.RequireExactlyOneCallToValidateToken(t,
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
|
||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
|
||||
@ -2857,7 +2907,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
|
||||
test.modifyStorage(t, oauthStore, authCode)
|
||||
}
|
||||
|
||||
subject = NewHandler(idps, oauthHelper)
|
||||
// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
|
||||
upstreamRefreshRetryOnErrorFactor := 1.0
|
||||
|
||||
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||
|
||||
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
|
||||
expectedNumberOfIDSessionsStored := 0
|
||||
|
@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
|
||||
// PerformRefreshArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
|
||||
type PerformRefreshArgs struct {
|
||||
Ctx context.Context
|
||||
RefreshToken string
|
||||
DN string
|
||||
ExpectedUsername string
|
||||
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
|
||||
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
|
||||
type ValidateTokenAndMergeWithUserInfoArgs struct {
|
||||
Ctx context.Context
|
||||
Tok *oauth2.Token
|
||||
ExpectedIDTokenNonce nonce.Nonce
|
||||
RequireIDToken bool
|
||||
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
||||
return u.URL
|
||||
}
|
||||
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
u.performRefreshCallCount++
|
||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||
Ctx: ctx,
|
||||
DN: storedRefreshAttributes.DN,
|
||||
ExpectedUsername: storedRefreshAttributes.Username,
|
||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
|
||||
}
|
||||
u.performRefreshCallCount++
|
||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||
Ctx: ctx,
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
return u.PerformRefreshFunc(ctx, refreshToken)
|
||||
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
|
||||
}
|
||||
u.validateTokenAndMergeWithUserInfoCallCount++
|
||||
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
|
||||
Ctx: ctx,
|
||||
Tok: tok,
|
||||
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||
RequireIDToken: requireIDToken,
|
||||
@ -472,46 +467,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
|
||||
)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
|
||||
t *testing.T,
|
||||
expectedNumberOfCalls int,
|
||||
expectedPerformedByUpstreamName string,
|
||||
expectedArgs *PerformRefreshArgs,
|
||||
) {
|
||||
t.Helper()
|
||||
var actualArgs *PerformRefreshArgs
|
||||
var actualNameOfUpstreamWhichMadeCall string
|
||||
actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
|
||||
actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
|
||||
actualCallCountAcrossAllUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
||||
actualArgs = upstreamOIDC.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
|
||||
}
|
||||
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
|
||||
actualArgs = upstreamLDAP.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
|
||||
}
|
||||
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
|
||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
|
||||
actualArgs = upstreamAD.performRefreshArgs[0]
|
||||
}
|
||||
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
|
||||
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
|
||||
}
|
||||
require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
|
||||
"should have been exactly one call to PerformRefresh() by all upstreams")
|
||||
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"PerformRefresh() was called on the wrong upstream at least once")
|
||||
}
|
||||
for _, actualArgs := range actualArgsOfAllCalls {
|
||||
require.Equal(t, expectedArgs, actualArgs,
|
||||
"PerformRefresh() was called with the wrong arguments at least once")
|
||||
}
|
||||
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
|
||||
"should have been exactly one call to PerformRefresh() by all upstreams",
|
||||
)
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"PerformRefresh() was called on the wrong upstream",
|
||||
)
|
||||
require.Equal(t, expectedArgs, actualArgs)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user