Compare commits
3 Commits
main
...
upstream-o
Author | SHA1 | Date | |
---|---|---|---|
|
7f9867598c | ||
|
ba83c12f93 | ||
|
3301a62053 |
@ -8,10 +8,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/ory/fosite"
|
"github.com/ory/fosite"
|
||||||
"github.com/ory/x/errorsx"
|
"github.com/ory/x/errorsx"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"k8s.io/apimachinery/pkg/util/wait"
|
||||||
|
"k8s.io/client-go/util/retry"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/httputil/httperr"
|
"go.pinniped.dev/internal/httputil/httperr"
|
||||||
"go.pinniped.dev/internal/oidc"
|
"go.pinniped.dev/internal/oidc"
|
||||||
@ -39,6 +42,18 @@ var (
|
|||||||
func NewHandler(
|
func NewHandler(
|
||||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||||
oauthHelper fosite.OAuth2Provider,
|
oauthHelper fosite.OAuth2Provider,
|
||||||
|
) http.Handler {
|
||||||
|
// Each retry of a failed upstream refresh will multiply the previous sleep duration by this factor.
|
||||||
|
// This only exists as a parameter so that unit tests can override it to avoid running slowly.
|
||||||
|
upstreamRefreshRetryOnErrorFactor := 4.0
|
||||||
|
|
||||||
|
return newHandler(idpLister, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHandler(
|
||||||
|
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||||
|
oauthHelper fosite.OAuth2Provider,
|
||||||
|
upstreamRefreshRetryOnErrorFactor float64,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
session := psession.NewPinnipedSession()
|
session := psession.NewPinnipedSession()
|
||||||
@ -55,7 +70,7 @@ func NewHandler(
|
|||||||
// The session, requested scopes, and requested audience from the original authorize request was retrieved
|
// The session, requested scopes, and requested audience from the original authorize request was retrieved
|
||||||
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
|
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
|
||||||
// have already been granted on the accessRequest.
|
// have already been granted on the accessRequest.
|
||||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
|
err = upstreamRefresh(r.Context(), accessRequest, idpLister, upstreamRefreshRetryOnErrorFactor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
|
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
|
||||||
oauthHelper.WriteAccessError(w, accessRequest, err)
|
oauthHelper.WriteAccessError(w, accessRequest, err)
|
||||||
@ -76,7 +91,12 @@ func NewHandler(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
func upstreamRefresh(
|
||||||
|
ctx context.Context,
|
||||||
|
accessRequest fosite.AccessRequester,
|
||||||
|
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||||
|
retryOnErrorFactor float64,
|
||||||
|
) error {
|
||||||
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
||||||
|
|
||||||
customSessionData := session.Custom
|
customSessionData := session.Custom
|
||||||
@ -91,7 +111,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
|||||||
|
|
||||||
switch customSessionData.ProviderType {
|
switch customSessionData.ProviderType {
|
||||||
case psession.ProviderTypeOIDC:
|
case psession.ProviderTypeOIDC:
|
||||||
return upstreamOIDCRefresh(ctx, session, providerCache)
|
return upstreamOIDCRefresh(ctx, session, providerCache, retryOnErrorFactor)
|
||||||
case psession.ProviderTypeLDAP:
|
case psession.ProviderTypeLDAP:
|
||||||
return upstreamLDAPRefresh(ctx, providerCache, session)
|
return upstreamLDAPRefresh(ctx, providerCache, session)
|
||||||
case psession.ProviderTypeActiveDirectory:
|
case psession.ProviderTypeActiveDirectory:
|
||||||
@ -101,7 +121,12 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
func upstreamOIDCRefresh(
|
||||||
|
ctx context.Context,
|
||||||
|
session *psession.PinnipedSession,
|
||||||
|
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||||
|
retryOnErrorFactor float64,
|
||||||
|
) error {
|
||||||
s := session.Custom
|
s := session.Custom
|
||||||
if s.OIDC == nil {
|
if s.OIDC == nil {
|
||||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||||
@ -125,7 +150,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
|
|
||||||
var tokens *oauth2.Token
|
var tokens *oauth2.Token
|
||||||
if refreshTokenStored {
|
if refreshTokenStored {
|
||||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
tokens, err = performUpstreamOIDCRefreshWithRetriesOnError(ctx, p, s, retryOnErrorFactor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||||
"Upstream refresh failed.",
|
"Upstream refresh failed.",
|
||||||
@ -189,6 +214,47 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func performUpstreamOIDCRefreshWithRetriesOnError(
|
||||||
|
ctx context.Context,
|
||||||
|
p provider.UpstreamOIDCIdentityProviderI,
|
||||||
|
s *psession.CustomSessionData,
|
||||||
|
retryOnErrorFactor float64,
|
||||||
|
) (*oauth2.Token, error) {
|
||||||
|
var tokens *oauth2.Token
|
||||||
|
|
||||||
|
// For the default retryOnErrorFactor of 4.0 this backoff means...
|
||||||
|
// Try once, then retry upon error after sleeps of 50ms, 0.2s, 0.8s, 3.2s, and 12.8s.
|
||||||
|
// Give up after a total of 6 tries over ~17s if they all resulted in errors.
|
||||||
|
backoff := wait.Backoff{Steps: 6, Duration: 50 * time.Millisecond, Factor: retryOnErrorFactor}
|
||||||
|
|
||||||
|
isRetryableError := func(err error) bool {
|
||||||
|
plog.DebugErr("upstream refresh request failed in retry loop", err,
|
||||||
|
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return false // Stop retrying if the context was closed (cancelled or timed out).
|
||||||
|
}
|
||||||
|
retrieveError := &oauth2.RetrieveError{}
|
||||||
|
if errors.As(err, &retrieveError) {
|
||||||
|
return retrieveError.Response.StatusCode >= 500 // 5xx statuses are inconclusive and might be worth retrying.
|
||||||
|
}
|
||||||
|
return true // Retry any other errors, e.g. connection errors.
|
||||||
|
}
|
||||||
|
|
||||||
|
performRefreshOnce := func() error {
|
||||||
|
var err error
|
||||||
|
// Timeout to more likely have a chance to retry before a client gets tired of waiting and disconnects.
|
||||||
|
timeoutCtx, cancel := context.WithTimeout(ctx, 45*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
tokens, err = p.PerformRefresh(timeoutCtx, s.OIDC.UpstreamRefreshToken)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := retry.OnError(backoff, isRetryableError, performRefreshOnce)
|
||||||
|
|
||||||
|
// If all retries failed, then err will hold the error of the final failed retry.
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
|
||||||
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
|
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error {
|
||||||
s := session.Custom
|
s := session.Custom
|
||||||
|
|
||||||
|
@ -225,6 +225,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type expectedUpstreamRefresh struct {
|
type expectedUpstreamRefresh struct {
|
||||||
|
numberOfRetryAttempts int // number of expected retries, not including the original refresh attempt
|
||||||
performedByUpstreamName string
|
performedByUpstreamName string
|
||||||
args *oidctestutil.PerformRefreshArgs
|
args *oidctestutil.PerformRefreshArgs
|
||||||
}
|
}
|
||||||
@ -956,7 +957,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
return &expectedUpstreamRefresh{
|
return &expectedUpstreamRefresh{
|
||||||
performedByUpstreamName: oidcUpstreamName,
|
performedByUpstreamName: oidcUpstreamName,
|
||||||
args: &oidctestutil.PerformRefreshArgs{
|
args: &oidctestutil.PerformRefreshArgs{
|
||||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
|
||||||
RefreshToken: oidcUpstreamInitialRefreshToken,
|
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -966,7 +966,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
return &expectedUpstreamRefresh{
|
return &expectedUpstreamRefresh{
|
||||||
performedByUpstreamName: ldapUpstreamName,
|
performedByUpstreamName: ldapUpstreamName,
|
||||||
args: &oidctestutil.PerformRefreshArgs{
|
args: &oidctestutil.PerformRefreshArgs{
|
||||||
Ctx: nil,
|
|
||||||
DN: ldapUpstreamDN,
|
DN: ldapUpstreamDN,
|
||||||
ExpectedSubject: goodSubject,
|
ExpectedSubject: goodSubject,
|
||||||
ExpectedUsername: goodUsername,
|
ExpectedUsername: goodUsername,
|
||||||
@ -978,7 +977,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
return &expectedUpstreamRefresh{
|
return &expectedUpstreamRefresh{
|
||||||
performedByUpstreamName: activeDirectoryUpstreamName,
|
performedByUpstreamName: activeDirectoryUpstreamName,
|
||||||
args: &oidctestutil.PerformRefreshArgs{
|
args: &oidctestutil.PerformRefreshArgs{
|
||||||
Ctx: nil,
|
|
||||||
DN: activeDirectoryUpstreamDN,
|
DN: activeDirectoryUpstreamDN,
|
||||||
ExpectedSubject: goodSubject,
|
ExpectedSubject: goodSubject,
|
||||||
ExpectedUsername: goodUsername,
|
ExpectedUsername: goodUsername,
|
||||||
@ -990,7 +988,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
return &expectedUpstreamValidateTokens{
|
return &expectedUpstreamValidateTokens{
|
||||||
performedByUpstreamName: oidcUpstreamName,
|
performedByUpstreamName: oidcUpstreamName,
|
||||||
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
args: &oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
|
||||||
Tok: expectedTokens,
|
Tok: expectedTokens,
|
||||||
ExpectedIDTokenNonce: "", // always expect empty string
|
ExpectedIDTokenNonce: "", // always expect empty string
|
||||||
RequireUserInfo: false,
|
RequireUserInfo: false,
|
||||||
@ -1155,7 +1152,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
|
wantUpstreamOIDCValidateTokenCall: &expectedUpstreamValidateTokens{
|
||||||
oidcUpstreamName,
|
oidcUpstreamName,
|
||||||
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
&oidctestutil.ValidateTokenAndMergeWithUserInfoArgs{
|
||||||
Ctx: nil, // this will be filled in with the actual request context by the test below
|
|
||||||
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
|
Tok: &oauth2.Token{AccessToken: oidcUpstreamAccessToken}, // only the old access token
|
||||||
ExpectedIDTokenNonce: "", // always expect empty string
|
ExpectedIDTokenNonce: "", // always expect empty string
|
||||||
RequireIDToken: false,
|
RequireIDToken: false,
|
||||||
@ -1794,7 +1790,7 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "when the upstream refresh fails during the refresh request",
|
name: "when the upstream refresh fails with a generic error during the refresh request it retries the upstream refresh",
|
||||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||||
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
|
WithPerformRefreshError(errors.New("some upstream refresh error")).Build()),
|
||||||
authcodeExchange: authcodeExchangeInputs{
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
@ -1804,7 +1800,63 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
},
|
},
|
||||||
refreshRequest: refreshRequestInputs{
|
refreshRequest: refreshRequestInputs{
|
||||||
want: tokenEndpointResponseExpectedValues{
|
want: tokenEndpointResponseExpectedValues{
|
||||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
|
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||||
|
numberOfRetryAttempts: 5, // every attempt returns a generic error, so it should reach the maximum number of retries
|
||||||
|
performedByUpstreamName: oidcUpstreamName,
|
||||||
|
args: &oidctestutil.PerformRefreshArgs{
|
||||||
|
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
wantErrorResponseBody: here.Doc(`
|
||||||
|
{
|
||||||
|
"error": "error",
|
||||||
|
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the upstream refresh fails with an http status 5xx error during the refresh request it retries the upstream refresh",
|
||||||
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||||
|
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusServiceUnavailable}}).Build()),
|
||||||
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
|
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||||
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||||
|
},
|
||||||
|
refreshRequest: refreshRequestInputs{
|
||||||
|
want: tokenEndpointResponseExpectedValues{
|
||||||
|
wantUpstreamRefreshCall: &expectedUpstreamRefresh{
|
||||||
|
numberOfRetryAttempts: 5, // every attempt returns a 5xx error, so it should reach the maximum number of retries
|
||||||
|
performedByUpstreamName: oidcUpstreamName,
|
||||||
|
args: &oidctestutil.PerformRefreshArgs{
|
||||||
|
RefreshToken: oidcUpstreamInitialRefreshToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
wantErrorResponseBody: here.Doc(`
|
||||||
|
{
|
||||||
|
"error": "error",
|
||||||
|
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "when the upstream refresh fails with an http status 4xx error during the refresh request it does not retry the upstream refresh",
|
||||||
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProviderBuilder().
|
||||||
|
WithPerformRefreshError(&oauth2.RetrieveError{Response: &http.Response{StatusCode: http.StatusForbidden}}).Build()),
|
||||||
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
|
customSessionData: initialUpstreamOIDCRefreshTokenCustomSessionData(),
|
||||||
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCRefreshTokenCustomSessionData()),
|
||||||
|
},
|
||||||
|
refreshRequest: refreshRequestInputs{
|
||||||
|
want: tokenEndpointResponseExpectedValues{
|
||||||
|
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), // no retries should happen after the original request returns a 4xx status
|
||||||
wantStatus: http.StatusUnauthorized,
|
wantStatus: http.StatusUnauthorized,
|
||||||
wantErrorResponseBody: here.Doc(`
|
wantErrorResponseBody: here.Doc(`
|
||||||
{
|
{
|
||||||
@ -2715,9 +2767,8 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
if test.modifyRefreshTokenStorage != nil {
|
if test.modifyRefreshTokenStorage != nil {
|
||||||
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
|
test.modifyRefreshTokenStorage(t, oauthStore, firstRefreshToken)
|
||||||
}
|
}
|
||||||
reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context")
|
|
||||||
req := httptest.NewRequest("POST", "/path/shouldn't/matter",
|
req := httptest.NewRequest("POST", "/path/shouldn't/matter",
|
||||||
happyRefreshRequestBody(firstRefreshToken).ReadCloser()).WithContext(reqContext)
|
happyRefreshRequestBody(firstRefreshToken).ReadCloser())
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
if test.refreshRequest.modifyTokenRequest != nil {
|
if test.refreshRequest.modifyTokenRequest != nil {
|
||||||
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
|
test.refreshRequest.modifyTokenRequest(req, firstRefreshToken, parsedAuthcodeExchangeResponseBody["access_token"].(string))
|
||||||
@ -2730,8 +2781,8 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
|
|
||||||
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
|
// Test that we did or did not make a call to the upstream OIDC provider interface to perform a token refresh.
|
||||||
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
|
if test.refreshRequest.want.wantUpstreamRefreshCall != nil {
|
||||||
test.refreshRequest.want.wantUpstreamRefreshCall.args.Ctx = reqContext
|
test.idps.RequireExactlyNCallsToPerformRefresh(t,
|
||||||
test.idps.RequireExactlyOneCallToPerformRefresh(t,
|
test.refreshRequest.want.wantUpstreamRefreshCall.numberOfRetryAttempts+1, // plus one for the original attempt
|
||||||
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
|
test.refreshRequest.want.wantUpstreamRefreshCall.performedByUpstreamName,
|
||||||
test.refreshRequest.want.wantUpstreamRefreshCall.args,
|
test.refreshRequest.want.wantUpstreamRefreshCall.args,
|
||||||
)
|
)
|
||||||
@ -2742,7 +2793,6 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the
|
// Test that we did or did not make a call to the upstream OIDC provider interface to validate the
|
||||||
// new ID token that was returned by the upstream refresh.
|
// new ID token that was returned by the upstream refresh.
|
||||||
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
|
if test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall != nil {
|
||||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args.Ctx = reqContext
|
|
||||||
test.idps.RequireExactlyOneCallToValidateToken(t,
|
test.idps.RequireExactlyOneCallToValidateToken(t,
|
||||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
|
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.performedByUpstreamName,
|
||||||
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
|
test.refreshRequest.want.wantUpstreamOIDCValidateTokenCall.args,
|
||||||
@ -2857,7 +2907,10 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs, idps p
|
|||||||
test.modifyStorage(t, oauthStore, authCode)
|
test.modifyStorage(t, oauthStore, authCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
subject = NewHandler(idps, oauthHelper)
|
// Use a faster factor for this test to avoid the runtime penalty of exponential backoff on errors.
|
||||||
|
upstreamRefreshRetryOnErrorFactor := 1.0
|
||||||
|
|
||||||
|
subject = newHandler(idps, oauthHelper, upstreamRefreshRetryOnErrorFactor)
|
||||||
|
|
||||||
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
|
authorizeEndpointGrantedOpenIDScope := strings.Contains(authRequest.Form.Get("scope"), "openid")
|
||||||
expectedNumberOfIDSessionsStored := 0
|
expectedNumberOfIDSessionsStored := 0
|
||||||
|
@ -61,7 +61,6 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
|
|||||||
// PerformRefreshArgs is used to spy on calls to
|
// PerformRefreshArgs is used to spy on calls to
|
||||||
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
|
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
|
||||||
type PerformRefreshArgs struct {
|
type PerformRefreshArgs struct {
|
||||||
Ctx context.Context
|
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
DN string
|
DN string
|
||||||
ExpectedUsername string
|
ExpectedUsername string
|
||||||
@ -79,7 +78,6 @@ type RevokeTokenArgs struct {
|
|||||||
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
|
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
|
||||||
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
|
// TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc().
|
||||||
type ValidateTokenAndMergeWithUserInfoArgs struct {
|
type ValidateTokenAndMergeWithUserInfoArgs struct {
|
||||||
Ctx context.Context
|
|
||||||
Tok *oauth2.Token
|
Tok *oauth2.Token
|
||||||
ExpectedIDTokenNonce nonce.Nonce
|
ExpectedIDTokenNonce nonce.Nonce
|
||||||
RequireIDToken bool
|
RequireIDToken bool
|
||||||
@ -120,13 +118,12 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
|
|||||||
return u.URL
|
return u.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(_ context.Context, storedRefreshAttributes provider.StoredRefreshAttributes) error {
|
||||||
if u.performRefreshArgs == nil {
|
if u.performRefreshArgs == nil {
|
||||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||||
}
|
}
|
||||||
u.performRefreshCallCount++
|
u.performRefreshCallCount++
|
||||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||||
Ctx: ctx,
|
|
||||||
DN: storedRefreshAttributes.DN,
|
DN: storedRefreshAttributes.DN,
|
||||||
ExpectedUsername: storedRefreshAttributes.Username,
|
ExpectedUsername: storedRefreshAttributes.Username,
|
||||||
ExpectedSubject: storedRefreshAttributes.Subject,
|
ExpectedSubject: storedRefreshAttributes.Subject,
|
||||||
@ -286,7 +283,6 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
|
|||||||
}
|
}
|
||||||
u.performRefreshCallCount++
|
u.performRefreshCallCount++
|
||||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||||
Ctx: ctx,
|
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
})
|
})
|
||||||
return u.PerformRefreshFunc(ctx, refreshToken)
|
return u.PerformRefreshFunc(ctx, refreshToken)
|
||||||
@ -333,7 +329,6 @@ func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx
|
|||||||
}
|
}
|
||||||
u.validateTokenAndMergeWithUserInfoCallCount++
|
u.validateTokenAndMergeWithUserInfoCallCount++
|
||||||
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
|
u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{
|
||||||
Ctx: ctx,
|
|
||||||
Tok: tok,
|
Tok: tok,
|
||||||
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||||
RequireIDToken: requireIDToken,
|
RequireIDToken: requireIDToken,
|
||||||
@ -472,46 +467,44 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
|
func (b *UpstreamIDPListerBuilder) RequireExactlyNCallsToPerformRefresh(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
|
expectedNumberOfCalls int,
|
||||||
expectedPerformedByUpstreamName string,
|
expectedPerformedByUpstreamName string,
|
||||||
expectedArgs *PerformRefreshArgs,
|
expectedArgs *PerformRefreshArgs,
|
||||||
) {
|
) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
var actualArgs *PerformRefreshArgs
|
actualArgsOfAllCalls := make([]*PerformRefreshArgs, 0)
|
||||||
var actualNameOfUpstreamWhichMadeCall string
|
actualNamesOfUpstreamWhichMadeCalls := make([]string, 0)
|
||||||
actualCallCountAcrossAllUpstreams := 0
|
actualCallCountAcrossAllUpstreams := 0
|
||||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||||
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
|
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
|
||||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||||
if callCountOnThisUpstream == 1 {
|
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamOIDC.Name)
|
||||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamOIDC.performRefreshArgs[0])
|
||||||
actualArgs = upstreamOIDC.performRefreshArgs[0]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
|
for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders {
|
||||||
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
|
callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount
|
||||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||||
if callCountOnThisUpstream == 1 {
|
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamLDAP.Name)
|
||||||
actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name
|
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamLDAP.performRefreshArgs[0])
|
||||||
actualArgs = upstreamLDAP.performRefreshArgs[0]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
|
for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders {
|
||||||
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
|
callCountOnThisUpstream := upstreamAD.performRefreshCallCount
|
||||||
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
actualCallCountAcrossAllUpstreams += callCountOnThisUpstream
|
||||||
if callCountOnThisUpstream == 1 {
|
actualNamesOfUpstreamWhichMadeCalls = append(actualNamesOfUpstreamWhichMadeCalls, upstreamAD.Name)
|
||||||
actualNameOfUpstreamWhichMadeCall = upstreamAD.Name
|
actualArgsOfAllCalls = append(actualArgsOfAllCalls, upstreamAD.performRefreshArgs[0])
|
||||||
actualArgs = upstreamAD.performRefreshArgs[0]
|
}
|
||||||
}
|
require.Equal(t, expectedNumberOfCalls, actualCallCountAcrossAllUpstreams,
|
||||||
|
"should have been exactly one call to PerformRefresh() by all upstreams")
|
||||||
|
for _, actualNameOfUpstreamWhichMadeCall := range actualNamesOfUpstreamWhichMadeCalls {
|
||||||
|
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||||
|
"PerformRefresh() was called on the wrong upstream at least once")
|
||||||
|
}
|
||||||
|
for _, actualArgs := range actualArgsOfAllCalls {
|
||||||
|
require.Equal(t, expectedArgs, actualArgs,
|
||||||
|
"PerformRefresh() was called with the wrong arguments at least once")
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, actualCallCountAcrossAllUpstreams,
|
|
||||||
"should have been exactly one call to PerformRefresh() by all upstreams",
|
|
||||||
)
|
|
||||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
|
||||||
"PerformRefresh() was called on the wrong upstream",
|
|
||||||
)
|
|
||||||
require.Equal(t, expectedArgs, actualArgs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
|
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user