Support revocation of access tokens in UpstreamOIDCIdentityProviderI
- Rename the RevokeRefreshToken() function to RevokeToken() and make it take the token type (refresh or access) as a new parameter. - This is a prefactor getting ready to support revocation of upstream access tokens in the garbage collection handler.
This commit is contained in:
parent
edd3547977
commit
b981055d31
@ -243,7 +243,7 @@ func (c *garbageCollectorController) revokeUpstreamOIDCRefreshToken(ctx context.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint.
|
// Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint.
|
||||||
err := foundOIDCIdentityProviderI.RevokeRefreshToken(ctx, customSessionData.OIDC.UpstreamRefreshToken)
|
err := foundOIDCIdentityProviderI.RevokeToken(ctx, customSessionData.OIDC.UpstreamRefreshToken, provider.RefreshTokenType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This could be a network failure, a 503 result which we should retry
|
// This could be a network failure, a 503 result which we should retry
|
||||||
// (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1),
|
// (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1),
|
||||||
|
@ -366,18 +366,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// The upstream refresh token is only revoked for the active authcode session.
|
// The upstream refresh token is only revoked for the active authcode session.
|
||||||
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
|
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
|
||||||
"upstream-oidc-provider-name",
|
"upstream-oidc-provider-name",
|
||||||
&oidctestutil.RevokeRefreshTokenArgs{
|
&oidctestutil.RevokeTokenArgs{
|
||||||
Ctx: syncContext.Context,
|
Ctx: syncContext.Context,
|
||||||
RefreshToken: "fake-upstream-refresh-token",
|
Token: "fake-upstream-refresh-token",
|
||||||
|
TokenType: provider.RefreshTokenType,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -448,14 +449,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// Nothing to revoke since we couldn't read the invalid secret.
|
// Nothing to revoke since we couldn't read the invalid secret.
|
||||||
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
|
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
|
||||||
|
|
||||||
// The invalid authcode session secrets is still deleted because it is expired.
|
// The invalid authcode session secrets is still deleted because it is expired.
|
||||||
r.ElementsMatch(
|
r.ElementsMatch(
|
||||||
@ -524,14 +525,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// Nothing to revoke since we couldn't find the upstream in the cache.
|
// Nothing to revoke since we couldn't find the upstream in the cache.
|
||||||
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
|
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
|
||||||
|
|
||||||
// The authcode session secrets is still deleted because it is expired.
|
// The authcode session secrets is still deleted because it is expired.
|
||||||
r.ElementsMatch(
|
r.ElementsMatch(
|
||||||
@ -600,14 +601,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// Nothing to revoke since we couldn't find the upstream in the cache.
|
// Nothing to revoke since we couldn't find the upstream in the cache.
|
||||||
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
|
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
|
||||||
|
|
||||||
// The authcode session secrets is still deleted because it is expired.
|
// The authcode session secrets is still deleted because it is expired.
|
||||||
r.ElementsMatch(
|
r.ElementsMatch(
|
||||||
@ -677,18 +678,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
|
WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// Tried to revoke it, although this revocation will fail.
|
// Tried to revoke it, although this revocation will fail.
|
||||||
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
|
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
|
||||||
"upstream-oidc-provider-name",
|
"upstream-oidc-provider-name",
|
||||||
&oidctestutil.RevokeRefreshTokenArgs{
|
&oidctestutil.RevokeTokenArgs{
|
||||||
Ctx: syncContext.Context,
|
Ctx: syncContext.Context,
|
||||||
RefreshToken: "fake-upstream-refresh-token",
|
Token: "fake-upstream-refresh-token",
|
||||||
|
TokenType: provider.RefreshTokenType,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -749,18 +751,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
|
WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// Tried to revoke it, although this revocation will fail.
|
// Tried to revoke it, although this revocation will fail.
|
||||||
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
|
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
|
||||||
"upstream-oidc-provider-name",
|
"upstream-oidc-provider-name",
|
||||||
&oidctestutil.RevokeRefreshTokenArgs{
|
&oidctestutil.RevokeTokenArgs{
|
||||||
Ctx: syncContext.Context,
|
Ctx: syncContext.Context,
|
||||||
RefreshToken: "fake-upstream-refresh-token",
|
Token: "fake-upstream-refresh-token",
|
||||||
|
TokenType: provider.RefreshTokenType,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -875,18 +878,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// The upstream refresh token is only revoked for the downstream session which had offline_access granted.
|
// The upstream refresh token is only revoked for the downstream session which had offline_access granted.
|
||||||
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
|
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
|
||||||
"upstream-oidc-provider-name",
|
"upstream-oidc-provider-name",
|
||||||
&oidctestutil.RevokeRefreshTokenArgs{
|
&oidctestutil.RevokeTokenArgs{
|
||||||
Ctx: syncContext.Context,
|
Ctx: syncContext.Context,
|
||||||
RefreshToken: "fake-upstream-refresh-token",
|
Token: "fake-upstream-refresh-token",
|
||||||
|
TokenType: provider.RefreshTokenType,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -958,18 +962,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
|
||||||
WithName("upstream-oidc-provider-name").
|
WithName("upstream-oidc-provider-name").
|
||||||
WithResourceUID("upstream-oidc-provider-uid").
|
WithResourceUID("upstream-oidc-provider-uid").
|
||||||
WithRevokeRefreshTokenError(nil)
|
WithRevokeTokenError(nil)
|
||||||
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
|
||||||
|
|
||||||
startInformersAndController(idpListerBuilder.Build())
|
startInformersAndController(idpListerBuilder.Build())
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
|
|
||||||
// The upstream refresh token is revoked.
|
// The upstream refresh token is revoked.
|
||||||
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
|
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
|
||||||
"upstream-oidc-provider-name",
|
"upstream-oidc-provider-name",
|
||||||
&oidctestutil.RevokeRefreshTokenArgs{
|
&oidctestutil.RevokeTokenArgs{
|
||||||
Ctx: syncContext.Context,
|
Ctx: syncContext.Context,
|
||||||
RefreshToken: "fake-upstream-refresh-token",
|
Token: "fake-upstream-refresh-token",
|
||||||
|
TokenType: provider.RefreshTokenType,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1015,7 +1020,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
|
|||||||
r.False(syncContext.Queue.(*testQueue).called)
|
r.False(syncContext.Queue.(*testQueue).called)
|
||||||
|
|
||||||
// Run sync again when not enough time has passed since the most recent run, so no delete
|
// Run sync again when not enough time has passed since the most recent run, so no delete
|
||||||
// operations should happen even though there is a expired secret now.
|
// operations should happen even though there is an expired secret now.
|
||||||
fakeClock.Step(29 * time.Second)
|
fakeClock.Step(29 * time.Second)
|
||||||
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
|
||||||
require.Empty(t, kubeClient.Actions())
|
require.Empty(t, kubeClient.Actions())
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
provider "go.pinniped.dev/internal/oidc/provider"
|
||||||
nonce "go.pinniped.dev/pkg/oidcclient/nonce"
|
nonce "go.pinniped.dev/pkg/oidcclient/nonce"
|
||||||
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
|
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||||
pkce "go.pinniped.dev/pkg/oidcclient/pkce"
|
pkce "go.pinniped.dev/pkg/oidcclient/pkce"
|
||||||
@ -215,18 +216,18 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, ar
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRefreshToken mocks base method.
|
// RevokeToken mocks base method.
|
||||||
func (m *MockUpstreamOIDCIdentityProviderI) RevokeRefreshToken(arg0 context.Context, arg1 string) error {
|
func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 provider.RevocableTokenType) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "RevokeRefreshToken", arg0, arg1)
|
ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRefreshToken indicates an expected call of RevokeRefreshToken.
|
// RevokeToken indicates an expected call of RevokeToken.
|
||||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0, arg1 interface{}) *gomock.Call {
|
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeToken(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeRefreshToken), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeToken), arg0, arg1, arg2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateToken mocks base method.
|
// ValidateToken mocks base method.
|
||||||
|
@ -17,6 +17,14 @@ import (
|
|||||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RevocableTokenType string
|
||||||
|
|
||||||
|
// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
|
||||||
|
const (
|
||||||
|
RefreshTokenType RevocableTokenType = "refresh_token"
|
||||||
|
AccessTokenType RevocableTokenType = "access_token"
|
||||||
|
)
|
||||||
|
|
||||||
type UpstreamOIDCIdentityProviderI interface {
|
type UpstreamOIDCIdentityProviderI interface {
|
||||||
// GetName returns a name for this upstream provider, which will be used as a component of the path for the
|
// GetName returns a name for this upstream provider, which will be used as a component of the path for the
|
||||||
// callback endpoint hosted by the Supervisor.
|
// callback endpoint hosted by the Supervisor.
|
||||||
@ -68,8 +76,8 @@ type UpstreamOIDCIdentityProviderI interface {
|
|||||||
// validate the ID token.
|
// validate the ID token.
|
||||||
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
||||||
|
|
||||||
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint.
|
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
|
||||||
RevokeRefreshToken(ctx context.Context, refreshToken string) error
|
RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error
|
||||||
|
|
||||||
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
|
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
|
||||||
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
|
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
|
||||||
|
@ -68,11 +68,12 @@ type PerformRefreshArgs struct {
|
|||||||
ExpectedSubject string
|
ExpectedSubject string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRefreshTokenArgs is used to spy on calls to
|
// RevokeTokenArgs is used to spy on calls to
|
||||||
// TestUpstreamOIDCIdentityProvider.RevokeRefreshTokenArgsFunc().
|
// TestUpstreamOIDCIdentityProvider.RevokeTokenArgsFunc().
|
||||||
type RevokeRefreshTokenArgs struct {
|
type RevokeTokenArgs struct {
|
||||||
Ctx context.Context
|
Ctx context.Context
|
||||||
RefreshToken string
|
Token string
|
||||||
|
TokenType provider.RevocableTokenType
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateTokenArgs is used to spy on calls to
|
// ValidateTokenArgs is used to spy on calls to
|
||||||
@ -166,7 +167,7 @@ type TestUpstreamOIDCIdentityProvider struct {
|
|||||||
|
|
||||||
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
||||||
|
|
||||||
RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error
|
RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error
|
||||||
|
|
||||||
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
|
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
|
||||||
|
|
||||||
@ -176,8 +177,8 @@ type TestUpstreamOIDCIdentityProvider struct {
|
|||||||
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
|
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
|
||||||
performRefreshCallCount int
|
performRefreshCallCount int
|
||||||
performRefreshArgs []*PerformRefreshArgs
|
performRefreshArgs []*PerformRefreshArgs
|
||||||
revokeRefreshTokenCallCount int
|
revokeTokenCallCount int
|
||||||
revokeRefreshTokenArgs []*RevokeRefreshTokenArgs
|
revokeTokenArgs []*RevokeTokenArgs
|
||||||
validateTokenCallCount int
|
validateTokenCallCount int
|
||||||
validateTokenArgs []*ValidateTokenArgs
|
validateTokenArgs []*ValidateTokenArgs
|
||||||
}
|
}
|
||||||
@ -278,16 +279,17 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
|
|||||||
return u.PerformRefreshFunc(ctx, refreshToken)
|
return u.PerformRefreshFunc(ctx, refreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
|
func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
|
||||||
if u.revokeRefreshTokenArgs == nil {
|
if u.revokeTokenArgs == nil {
|
||||||
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0)
|
u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
|
||||||
}
|
}
|
||||||
u.revokeRefreshTokenCallCount++
|
u.revokeTokenCallCount++
|
||||||
u.revokeRefreshTokenArgs = append(u.revokeRefreshTokenArgs, &RevokeRefreshTokenArgs{
|
u.revokeTokenArgs = append(u.revokeTokenArgs, &RevokeTokenArgs{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
RefreshToken: refreshToken,
|
Token: token,
|
||||||
|
TokenType: tokenType,
|
||||||
})
|
})
|
||||||
return u.RevokeRefreshTokenFunc(ctx, refreshToken)
|
return u.RevokeTokenFunc(ctx, token, tokenType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
|
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
|
||||||
@ -301,15 +303,15 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *Perform
|
|||||||
return u.performRefreshArgs[call]
|
return u.performRefreshArgs[call]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenCallCount() int {
|
func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenCallCount() int {
|
||||||
return u.performRefreshCallCount
|
return u.performRefreshCallCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *RevokeRefreshTokenArgs {
|
func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenArgs(call int) *RevokeTokenArgs {
|
||||||
if u.revokeRefreshTokenArgs == nil {
|
if u.revokeTokenArgs == nil {
|
||||||
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0)
|
u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
|
||||||
}
|
}
|
||||||
return u.revokeRefreshTokenArgs[call]
|
return u.revokeTokenArgs[call]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
||||||
@ -552,40 +554,40 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeRefreshToken(
|
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeToken(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
expectedPerformedByUpstreamName string,
|
expectedPerformedByUpstreamName string,
|
||||||
expectedArgs *RevokeRefreshTokenArgs,
|
expectedArgs *RevokeTokenArgs,
|
||||||
) {
|
) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
var actualArgs *RevokeRefreshTokenArgs
|
var actualArgs *RevokeTokenArgs
|
||||||
var actualNameOfUpstreamWhichMadeCall string
|
var actualNameOfUpstreamWhichMadeCall string
|
||||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||||
callCountOnThisUpstream := upstreamOIDC.revokeRefreshTokenCallCount
|
callCountOnThisUpstream := upstreamOIDC.revokeTokenCallCount
|
||||||
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
|
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
|
||||||
if callCountOnThisUpstream == 1 {
|
if callCountOnThisUpstream == 1 {
|
||||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
||||||
actualArgs = upstreamOIDC.revokeRefreshTokenArgs[0]
|
actualArgs = upstreamOIDC.revokeTokenArgs[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
|
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
|
||||||
"should have been exactly one call to RevokeRefreshToken() by all OIDC upstreams",
|
"should have been exactly one call to RevokeToken() by all OIDC upstreams",
|
||||||
)
|
)
|
||||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||||
"RevokeRefreshToken() was called on the wrong OIDC upstream",
|
"RevokeToken() was called on the wrong OIDC upstream",
|
||||||
)
|
)
|
||||||
require.Equal(t, expectedArgs, actualArgs)
|
require.Equal(t, expectedArgs, actualArgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeRefreshToken(t *testing.T) {
|
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeToken(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||||
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeRefreshTokenCallCount
|
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeTokenCallCount
|
||||||
}
|
}
|
||||||
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
|
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
|
||||||
"expected exactly zero calls to RevokeRefreshToken()",
|
"expected exactly zero calls to RevokeToken()",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -610,7 +612,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
|
|||||||
authcodeExchangeErr error
|
authcodeExchangeErr error
|
||||||
passwordGrantErr error
|
passwordGrantErr error
|
||||||
performRefreshErr error
|
performRefreshErr error
|
||||||
revokeRefreshTokenErr error
|
revokeTokenErr error
|
||||||
validateTokenErr error
|
validateTokenErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -727,8 +729,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err err
|
|||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeRefreshTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
|
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||||
u.revokeRefreshTokenErr = err
|
u.revokeTokenErr = err
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -761,8 +763,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
|
|||||||
}
|
}
|
||||||
return u.refreshedTokens, nil
|
return u.refreshedTokens, nil
|
||||||
},
|
},
|
||||||
RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error {
|
RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error {
|
||||||
return u.revokeRefreshTokenErr
|
return u.revokeTokenErr
|
||||||
},
|
},
|
||||||
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
||||||
if u.validateTokenErr != nil {
|
if u.validateTokenErr != nil {
|
||||||
|
@ -137,32 +137,36 @@ func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string
|
|||||||
return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token()
|
return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint.
|
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
|
||||||
func (p *ProviderConfig) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
|
func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
|
||||||
if p.RevocationURL == nil {
|
if p.RevocationURL == nil {
|
||||||
plog.Trace("RevokeRefreshToken() was called but upstream provider has no available revocation endpoint", "providerName", p.Name)
|
plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint",
|
||||||
|
"providerName", p.Name,
|
||||||
|
"tokenType", tokenType,
|
||||||
|
)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// First try using client auth in the request params.
|
// First try using client auth in the request params.
|
||||||
tryAnotherClientAuthMethod, err := p.tryRevokeRefreshToken(ctx, refreshToken, false)
|
tryAnotherClientAuthMethod, err := p.tryRevokeToken(ctx, token, tokenType, false)
|
||||||
if tryAnotherClientAuthMethod {
|
if tryAnotherClientAuthMethod {
|
||||||
// Try again using basic auth this time. Overwrite the first client auth error,
|
// Try again using basic auth this time. Overwrite the first client auth error,
|
||||||
// which isn't useful anymore when retrying.
|
// which isn't useful anymore when retrying.
|
||||||
_, err = p.tryRevokeRefreshToken(ctx, refreshToken, true)
|
_, err = p.tryRevokeToken(ctx, token, tokenType, true)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// tryRevokeRefreshToken will call the revocation endpoint using either basic auth or by including
|
// tryRevokeToken will call the revocation endpoint using either basic auth or by including
|
||||||
// client auth in the request params. It will return an error when the request failed. If the
|
// client auth in the request params. It will return an error when the request failed. If the
|
||||||
// request failed for a reason that might be due to bad client auth, then it will return true
|
// request failed for a reason that might be due to bad client auth, then it will return true
|
||||||
// for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying
|
// for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying
|
||||||
// again using the other client auth method.
|
// again using the other client auth method.
|
||||||
// RFC 7009 defines how to make a revocation request and how to interpret the response.
|
// RFC 7009 defines how to make a revocation request and how to interpret the response.
|
||||||
// See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details.
|
// See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details.
|
||||||
func (p *ProviderConfig) tryRevokeRefreshToken(
|
func (p *ProviderConfig) tryRevokeToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
refreshToken string,
|
token string,
|
||||||
|
tokenType provider.RevocableTokenType,
|
||||||
useBasicAuth bool,
|
useBasicAuth bool,
|
||||||
) (tryAnotherClientAuthMethod bool, err error) {
|
) (tryAnotherClientAuthMethod bool, err error) {
|
||||||
clientID := p.Config.ClientID
|
clientID := p.Config.ClientID
|
||||||
@ -171,8 +175,8 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
|||||||
httpClient := p.Client
|
httpClient := p.Client
|
||||||
|
|
||||||
params := url.Values{
|
params := url.Values{
|
||||||
"token": []string{refreshToken},
|
"token": []string{token},
|
||||||
"token_type_hint": []string{"refresh_token"},
|
"token_type_hint": []string{string(tokenType)},
|
||||||
}
|
}
|
||||||
if !useBasicAuth {
|
if !useBasicAuth {
|
||||||
params["client_id"] = []string{clientID}
|
params["client_id"] = []string{clientID}
|
||||||
@ -200,11 +204,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
|||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
case http.StatusOK:
|
case http.StatusOK:
|
||||||
// Success!
|
// Success!
|
||||||
plog.Trace("RevokeRefreshToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
plog.Trace("RevokeToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
||||||
return false, nil
|
return false, nil
|
||||||
case http.StatusBadRequest:
|
case http.StatusBadRequest:
|
||||||
// Bad request might be due to bad client auth method. Try to detect that.
|
// Bad request might be due to bad client auth method. Try to detect that.
|
||||||
plog.Trace("RevokeRefreshToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
plog.Trace("RevokeToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false,
|
return false,
|
||||||
@ -227,11 +231,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
|||||||
}
|
}
|
||||||
// Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again
|
// Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again
|
||||||
// using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
// using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||||
plog.Trace("RevokeRefreshToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
plog.Trace("RevokeToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
|
||||||
return true, err
|
return true, err
|
||||||
default:
|
default:
|
||||||
// Any other error is probably not due to failed client auth.
|
// Any other error is probably not due to failed client auth.
|
||||||
plog.Trace("RevokeRefreshToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode)
|
plog.Trace("RevokeToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode)
|
||||||
return false, fmt.Errorf("server responded with status %d", resp.StatusCode)
|
return false, fmt.Errorf("server responded with status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
"go.pinniped.dev/internal/mocks/mockkeyset"
|
||||||
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/testutil"
|
"go.pinniped.dev/internal/testutil"
|
||||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||||
@ -455,73 +456,114 @@ func TestProviderConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("RevokeRefreshToken", func(t *testing.T) {
|
t.Run("RevokeToken", func(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
nilRevocationURL bool
|
tokenType provider.RevocableTokenType
|
||||||
statusCodes []int
|
nilRevocationURL bool
|
||||||
returnErrBodies []string
|
statusCodes []int
|
||||||
wantErr string
|
returnErrBodies []string
|
||||||
wantNumRequests int
|
wantErr string
|
||||||
|
wantNumRequests int
|
||||||
|
wantTokenTypeHint string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "success without calling the server when there is no revocation URL set",
|
name: "success without calling the server when there is no revocation URL set for refresh token",
|
||||||
|
tokenType: provider.RefreshTokenType,
|
||||||
nilRevocationURL: true,
|
nilRevocationURL: true,
|
||||||
wantNumRequests: 0,
|
wantNumRequests: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success when the server returns 200 OK on the first call",
|
name: "success without calling the server when there is no revocation URL set for access token",
|
||||||
statusCodes: []int{http.StatusOK},
|
tokenType: provider.AccessTokenType,
|
||||||
wantNumRequests: 1,
|
nilRevocationURL: true,
|
||||||
|
wantNumRequests: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call",
|
name: "success when the server returns 200 OK on the first call for refresh token",
|
||||||
|
tokenType: provider.RefreshTokenType,
|
||||||
|
statusCodes: []int{http.StatusOK},
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success when the server returns 200 OK on the first call for access token",
|
||||||
|
tokenType: provider.AccessTokenType,
|
||||||
|
statusCodes: []int{http.StatusOK},
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "access_token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token",
|
||||||
|
tokenType: provider.RefreshTokenType,
|
||||||
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
|
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
|
||||||
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
|
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
|
||||||
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
|
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
|
||||||
wantNumRequests: 2,
|
wantNumRequests: 2,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
|
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token",
|
||||||
statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
|
tokenType: provider.AccessTokenType,
|
||||||
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
|
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
|
||||||
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`,
|
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
|
||||||
wantNumRequests: 2,
|
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
|
||||||
|
wantNumRequests: 2,
|
||||||
|
wantTokenTypeHint: "access_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
|
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
|
||||||
statusCodes: []int{http.StatusBadRequest},
|
tokenType: provider.RefreshTokenType,
|
||||||
returnErrBodies: []string{`invalid JSON body`},
|
statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
|
||||||
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`,
|
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
|
||||||
wantNumRequests: 1,
|
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`,
|
||||||
|
wantNumRequests: 2,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when the server returns 400 Bad Request with empty body",
|
name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
|
||||||
statusCodes: []int{http.StatusBadRequest},
|
tokenType: provider.RefreshTokenType,
|
||||||
returnErrBodies: []string{``},
|
statusCodes: []int{http.StatusBadRequest},
|
||||||
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`,
|
returnErrBodies: []string{`invalid JSON body`},
|
||||||
wantNumRequests: 1,
|
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`,
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
|
name: "error when the server returns 400 Bad Request with empty body",
|
||||||
statusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
|
tokenType: provider.RefreshTokenType,
|
||||||
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
|
statusCodes: []int{http.StatusBadRequest},
|
||||||
wantErr: "server responded with status 403",
|
returnErrBodies: []string{``},
|
||||||
wantNumRequests: 2,
|
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`,
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when server returns any other 400 error on first call",
|
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
|
||||||
statusCodes: []int{http.StatusBadRequest},
|
tokenType: provider.RefreshTokenType,
|
||||||
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
|
statusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
|
||||||
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`,
|
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
|
||||||
wantNumRequests: 1,
|
wantErr: "server responded with status 403",
|
||||||
|
wantNumRequests: 2,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error when server returns any other error aside from 400 on first call",
|
name: "error when server returns any other 400 error on first call",
|
||||||
statusCodes: []int{http.StatusForbidden},
|
tokenType: provider.RefreshTokenType,
|
||||||
returnErrBodies: []string{""},
|
statusCodes: []int{http.StatusBadRequest},
|
||||||
wantErr: "server responded with status 403",
|
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
|
||||||
wantNumRequests: 1,
|
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`,
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error when server returns any other error aside from 400 on first call",
|
||||||
|
tokenType: provider.RefreshTokenType,
|
||||||
|
statusCodes: []int{http.StatusForbidden},
|
||||||
|
returnErrBodies: []string{""},
|
||||||
|
wantErr: "server responded with status 403",
|
||||||
|
wantNumRequests: 1,
|
||||||
|
wantTokenTypeHint: "refresh_token",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -536,15 +578,15 @@ func TestProviderConfig(t *testing.T) {
|
|||||||
if numRequests == 1 {
|
if numRequests == 1 {
|
||||||
// First request should use client_id/client_secret params.
|
// First request should use client_id/client_secret params.
|
||||||
require.Equal(t, 4, len(r.Form))
|
require.Equal(t, 4, len(r.Form))
|
||||||
|
require.Equal(t, "test-upstream-token", r.Form.Get("token"))
|
||||||
|
require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
|
||||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
||||||
require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
|
require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
|
||||||
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint"))
|
|
||||||
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token"))
|
|
||||||
} else {
|
} else {
|
||||||
// Second request, if there is one, should use basic auth.
|
// Second request, if there is one, should use basic auth.
|
||||||
require.Equal(t, 2, len(r.Form))
|
require.Equal(t, 2, len(r.Form))
|
||||||
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint"))
|
require.Equal(t, "test-upstream-token", r.Form.Get("token"))
|
||||||
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token"))
|
require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
|
||||||
username, password, hasBasicAuth := r.BasicAuth()
|
username, password, hasBasicAuth := r.BasicAuth()
|
||||||
require.True(t, hasBasicAuth, "request should have had basic auth but did not")
|
require.True(t, hasBasicAuth, "request should have had basic auth but did not")
|
||||||
require.Equal(t, "test-client-id", username)
|
require.Equal(t, "test-client-id", username)
|
||||||
@ -574,9 +616,10 @@ func TestProviderConfig(t *testing.T) {
|
|||||||
p.RevocationURL = nil
|
p.RevocationURL = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = p.RevokeRefreshToken(
|
err = p.RevokeToken(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
"test-initial-refresh-token",
|
"test-upstream-token",
|
||||||
|
tt.tokenType,
|
||||||
)
|
)
|
||||||
|
|
||||||
require.Equal(t, tt.wantNumRequests, numRequests,
|
require.Equal(t, tt.wantNumRequests, numRequests,
|
||||||
|
Loading…
Reference in New Issue
Block a user