Support revocation of access tokens in UpstreamOIDCIdentityProviderI

- Rename the RevokeRefreshToken() function to RevokeToken() and make it
  take the token type (refresh or access) as a new parameter.
- This is a prefactor getting ready to support revocation of upstream
  access tokens in the garbage collection handler.
This commit is contained in:
Ryan Richard 2021-12-03 13:44:24 -08:00
parent edd3547977
commit b981055d31
7 changed files with 204 additions and 141 deletions

View File

@ -243,7 +243,7 @@ func (c *garbageCollectorController) revokeUpstreamOIDCRefreshToken(ctx context.
} }
// Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint. // Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint.
err := foundOIDCIdentityProviderI.RevokeRefreshToken(ctx, customSessionData.OIDC.UpstreamRefreshToken) err := foundOIDCIdentityProviderI.RevokeToken(ctx, customSessionData.OIDC.UpstreamRefreshToken, provider.RefreshTokenType)
if err != nil { if err != nil {
// This could be a network failure, a 503 result which we should retry // This could be a network failure, a 503 result which we should retry
// (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1), // (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1),

View File

@ -366,18 +366,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is only revoked for the active authcode session. // The upstream refresh token is only revoked for the active authcode session.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name", "upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
}, },
) )
@ -448,14 +449,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't read the invalid secret. // Nothing to revoke since we couldn't read the invalid secret.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The invalid authcode session secrets is still deleted because it is expired. // The invalid authcode session secrets is still deleted because it is expired.
r.ElementsMatch( r.ElementsMatch(
@ -524,14 +525,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't find the upstream in the cache. // Nothing to revoke since we couldn't find the upstream in the cache.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The authcode session secrets is still deleted because it is expired. // The authcode session secrets is still deleted because it is expired.
r.ElementsMatch( r.ElementsMatch(
@ -600,14 +601,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't find the upstream in the cache. // Nothing to revoke since we couldn't find the upstream in the cache.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t) idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The authcode session secrets is still deleted because it is expired. // The authcode session secrets is still deleted because it is expired.
r.ElementsMatch( r.ElementsMatch(
@ -677,18 +678,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Tried to revoke it, although this revocation will fail. // Tried to revoke it, although this revocation will fail.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name", "upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
}, },
) )
@ -749,18 +751,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Tried to revoke it, although this revocation will fail. // Tried to revoke it, although this revocation will fail.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name", "upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
}, },
) )
@ -875,18 +878,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is only revoked for the downstream session which had offline_access granted. // The upstream refresh token is only revoked for the downstream session which had offline_access granted.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name", "upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
}, },
) )
@ -958,18 +962,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name"). WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid"). WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil) WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build()) idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build()) startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is revoked. // The upstream refresh token is revoked.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t, idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name", "upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
}, },
) )
@ -1015,7 +1020,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
r.False(syncContext.Queue.(*testQueue).called) r.False(syncContext.Queue.(*testQueue).called)
// Run sync again when not enough time has passed since the most recent run, so no delete // Run sync again when not enough time has passed since the most recent run, so no delete
// operations should happen even though there is a expired secret now. // operations should happen even though there is an expired secret now.
fakeClock.Step(29 * time.Second) fakeClock.Step(29 * time.Second)
r.NoError(controllerlib.TestSync(t, subject, *syncContext)) r.NoError(controllerlib.TestSync(t, subject, *syncContext))
require.Empty(t, kubeClient.Actions()) require.Empty(t, kubeClient.Actions())

View File

@ -14,6 +14,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
provider "go.pinniped.dev/internal/oidc/provider"
nonce "go.pinniped.dev/pkg/oidcclient/nonce" nonce "go.pinniped.dev/pkg/oidcclient/nonce"
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
pkce "go.pinniped.dev/pkg/oidcclient/pkce" pkce "go.pinniped.dev/pkg/oidcclient/pkce"
@ -215,18 +216,18 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, ar
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
} }
// RevokeRefreshToken mocks base method. // RevokeToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) RevokeRefreshToken(arg0 context.Context, arg1 string) error { func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 provider.RevocableTokenType) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeRefreshToken", arg0, arg1) ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// RevokeRefreshToken indicates an expected call of RevokeRefreshToken. // RevokeToken indicates an expected call of RevokeToken.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0, arg1 interface{}) *gomock.Call { func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeToken(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeRefreshToken), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeToken), arg0, arg1, arg2)
} }
// ValidateToken mocks base method. // ValidateToken mocks base method.

View File

@ -17,6 +17,14 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
type RevocableTokenType string
// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
const (
RefreshTokenType RevocableTokenType = "refresh_token"
AccessTokenType RevocableTokenType = "access_token"
)
type UpstreamOIDCIdentityProviderI interface { type UpstreamOIDCIdentityProviderI interface {
// GetName returns a name for this upstream provider, which will be used as a component of the path for the // GetName returns a name for this upstream provider, which will be used as a component of the path for the
// callback endpoint hosted by the Supervisor. // callback endpoint hosted by the Supervisor.
@ -68,8 +76,8 @@ type UpstreamOIDCIdentityProviderI interface {
// validate the ID token. // validate the ID token.
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint. // RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
RevokeRefreshToken(ctx context.Context, refreshToken string) error RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated // into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated

View File

@ -68,11 +68,12 @@ type PerformRefreshArgs struct {
ExpectedSubject string ExpectedSubject string
} }
// RevokeRefreshTokenArgs is used to spy on calls to // RevokeTokenArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.RevokeRefreshTokenArgsFunc(). // TestUpstreamOIDCIdentityProvider.RevokeTokenArgsFunc().
type RevokeRefreshTokenArgs struct { type RevokeTokenArgs struct {
Ctx context.Context Ctx context.Context
RefreshToken string Token string
TokenType provider.RevocableTokenType
} }
// ValidateTokenArgs is used to spy on calls to // ValidateTokenArgs is used to spy on calls to
@ -166,7 +167,7 @@ type TestUpstreamOIDCIdentityProvider struct {
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
@ -176,8 +177,8 @@ type TestUpstreamOIDCIdentityProvider struct {
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
performRefreshCallCount int performRefreshCallCount int
performRefreshArgs []*PerformRefreshArgs performRefreshArgs []*PerformRefreshArgs
revokeRefreshTokenCallCount int revokeTokenCallCount int
revokeRefreshTokenArgs []*RevokeRefreshTokenArgs revokeTokenArgs []*RevokeTokenArgs
validateTokenCallCount int validateTokenCallCount int
validateTokenArgs []*ValidateTokenArgs validateTokenArgs []*ValidateTokenArgs
} }
@ -278,16 +279,17 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
return u.PerformRefreshFunc(ctx, refreshToken) return u.PerformRefreshFunc(ctx, refreshToken)
} }
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error { func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
if u.revokeRefreshTokenArgs == nil { if u.revokeTokenArgs == nil {
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0) u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
} }
u.revokeRefreshTokenCallCount++ u.revokeTokenCallCount++
u.revokeRefreshTokenArgs = append(u.revokeRefreshTokenArgs, &RevokeRefreshTokenArgs{ u.revokeTokenArgs = append(u.revokeTokenArgs, &RevokeTokenArgs{
Ctx: ctx, Ctx: ctx,
RefreshToken: refreshToken, Token: token,
TokenType: tokenType,
}) })
return u.RevokeRefreshTokenFunc(ctx, refreshToken) return u.RevokeTokenFunc(ctx, token, tokenType)
} }
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int { func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
@ -301,15 +303,15 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *Perform
return u.performRefreshArgs[call] return u.performRefreshArgs[call]
} }
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenCallCount() int { func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenCallCount() int {
return u.performRefreshCallCount return u.performRefreshCallCount
} }
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *RevokeRefreshTokenArgs { func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenArgs(call int) *RevokeTokenArgs {
if u.revokeRefreshTokenArgs == nil { if u.revokeTokenArgs == nil {
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0) u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
} }
return u.revokeRefreshTokenArgs[call] return u.revokeTokenArgs[call]
} }
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
@ -552,40 +554,40 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes
) )
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeRefreshToken( func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeToken(
t *testing.T, t *testing.T,
expectedPerformedByUpstreamName string, expectedPerformedByUpstreamName string,
expectedArgs *RevokeRefreshTokenArgs, expectedArgs *RevokeTokenArgs,
) { ) {
t.Helper() t.Helper()
var actualArgs *RevokeRefreshTokenArgs var actualArgs *RevokeTokenArgs
var actualNameOfUpstreamWhichMadeCall string var actualNameOfUpstreamWhichMadeCall string
actualCallCountAcrossAllOIDCUpstreams := 0 actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.revokeRefreshTokenCallCount callCountOnThisUpstream := upstreamOIDC.revokeTokenCallCount
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 { if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.revokeRefreshTokenArgs[0] actualArgs = upstreamOIDC.revokeTokenArgs[0]
} }
} }
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
"should have been exactly one call to RevokeRefreshToken() by all OIDC upstreams", "should have been exactly one call to RevokeToken() by all OIDC upstreams",
) )
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"RevokeRefreshToken() was called on the wrong OIDC upstream", "RevokeToken() was called on the wrong OIDC upstream",
) )
require.Equal(t, expectedArgs, actualArgs) require.Equal(t, expectedArgs, actualArgs)
} }
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeRefreshToken(t *testing.T) { func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeToken(t *testing.T) {
t.Helper() t.Helper()
actualCallCountAcrossAllOIDCUpstreams := 0 actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeRefreshTokenCallCount actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeTokenCallCount
} }
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to RevokeRefreshToken()", "expected exactly zero calls to RevokeToken()",
) )
} }
@ -610,7 +612,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
authcodeExchangeErr error authcodeExchangeErr error
passwordGrantErr error passwordGrantErr error
performRefreshErr error performRefreshErr error
revokeRefreshTokenErr error revokeTokenErr error
validateTokenErr error validateTokenErr error
} }
@ -727,8 +729,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err err
return u return u
} }
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeRefreshTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.revokeRefreshTokenErr = err u.revokeTokenErr = err
return u return u
} }
@ -761,8 +763,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
} }
return u.refreshedTokens, nil return u.refreshedTokens, nil
}, },
RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error { RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error {
return u.revokeRefreshTokenErr return u.revokeTokenErr
}, },
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
if u.validateTokenErr != nil { if u.validateTokenErr != nil {

View File

@ -137,32 +137,36 @@ func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string
return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token() return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token()
} }
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint. // RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
func (p *ProviderConfig) RevokeRefreshToken(ctx context.Context, refreshToken string) error { func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
if p.RevocationURL == nil { if p.RevocationURL == nil {
plog.Trace("RevokeRefreshToken() was called but upstream provider has no available revocation endpoint", "providerName", p.Name) plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint",
"providerName", p.Name,
"tokenType", tokenType,
)
return nil return nil
} }
// First try using client auth in the request params. // First try using client auth in the request params.
tryAnotherClientAuthMethod, err := p.tryRevokeRefreshToken(ctx, refreshToken, false) tryAnotherClientAuthMethod, err := p.tryRevokeToken(ctx, token, tokenType, false)
if tryAnotherClientAuthMethod { if tryAnotherClientAuthMethod {
// Try again using basic auth this time. Overwrite the first client auth error, // Try again using basic auth this time. Overwrite the first client auth error,
// which isn't useful anymore when retrying. // which isn't useful anymore when retrying.
_, err = p.tryRevokeRefreshToken(ctx, refreshToken, true) _, err = p.tryRevokeToken(ctx, token, tokenType, true)
} }
return err return err
} }
// tryRevokeRefreshToken will call the revocation endpoint using either basic auth or by including // tryRevokeToken will call the revocation endpoint using either basic auth or by including
// client auth in the request params. It will return an error when the request failed. If the // client auth in the request params. It will return an error when the request failed. If the
// request failed for a reason that might be due to bad client auth, then it will return true // request failed for a reason that might be due to bad client auth, then it will return true
// for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying // for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying
// again using the other client auth method. // again using the other client auth method.
// RFC 7009 defines how to make a revocation request and how to interpret the response. // RFC 7009 defines how to make a revocation request and how to interpret the response.
// See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details. // See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details.
func (p *ProviderConfig) tryRevokeRefreshToken( func (p *ProviderConfig) tryRevokeToken(
ctx context.Context, ctx context.Context,
refreshToken string, token string,
tokenType provider.RevocableTokenType,
useBasicAuth bool, useBasicAuth bool,
) (tryAnotherClientAuthMethod bool, err error) { ) (tryAnotherClientAuthMethod bool, err error) {
clientID := p.Config.ClientID clientID := p.Config.ClientID
@ -171,8 +175,8 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
httpClient := p.Client httpClient := p.Client
params := url.Values{ params := url.Values{
"token": []string{refreshToken}, "token": []string{token},
"token_type_hint": []string{"refresh_token"}, "token_type_hint": []string{string(tokenType)},
} }
if !useBasicAuth { if !useBasicAuth {
params["client_id"] = []string{clientID} params["client_id"] = []string{clientID}
@ -200,11 +204,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:
// Success! // Success!
plog.Trace("RevokeRefreshToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) plog.Trace("RevokeToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
return false, nil return false, nil
case http.StatusBadRequest: case http.StatusBadRequest:
// Bad request might be due to bad client auth method. Try to detect that. // Bad request might be due to bad client auth method. Try to detect that.
plog.Trace("RevokeRefreshToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth) plog.Trace("RevokeToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return false, return false,
@ -227,11 +231,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
} }
// Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again // Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again
// using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 // using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
plog.Trace("RevokeRefreshToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth) plog.Trace("RevokeToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
return true, err return true, err
default: default:
// Any other error is probably not due to failed client auth. // Any other error is probably not due to failed client auth.
plog.Trace("RevokeRefreshToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode) plog.Trace("RevokeToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode)
return false, fmt.Errorf("server responded with status %d", resp.StatusCode) return false, fmt.Errorf("server responded with status %d", resp.StatusCode)
} }
} }

View File

@ -24,6 +24,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -455,73 +456,114 @@ func TestProviderConfig(t *testing.T) {
} }
}) })
t.Run("RevokeRefreshToken", func(t *testing.T) { t.Run("RevokeToken", func(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
tokenType provider.RevocableTokenType
nilRevocationURL bool nilRevocationURL bool
statusCodes []int statusCodes []int
returnErrBodies []string returnErrBodies []string
wantErr string wantErr string
wantNumRequests int wantNumRequests int
wantTokenTypeHint string
}{ }{
{ {
name: "success without calling the server when there is no revocation URL set", name: "success without calling the server when there is no revocation URL set for refresh token",
tokenType: provider.RefreshTokenType,
nilRevocationURL: true, nilRevocationURL: true,
wantNumRequests: 0, wantNumRequests: 0,
}, },
{ {
name: "success when the server returns 200 OK on the first call", name: "success without calling the server when there is no revocation URL set for access token",
statusCodes: []int{http.StatusOK}, tokenType: provider.AccessTokenType,
wantNumRequests: 1, nilRevocationURL: true,
wantNumRequests: 0,
}, },
{ {
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call", name: "success when the server returns 200 OK on the first call for refresh token",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusOK},
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
{
name: "success when the server returns 200 OK on the first call for access token",
tokenType: provider.AccessTokenType,
statusCodes: []int{http.StatusOK},
wantNumRequests: 1,
wantTokenTypeHint: "access_token",
},
{
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusOK}, statusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
},
{
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token",
tokenType: provider.AccessTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
wantNumRequests: 2,
wantTokenTypeHint: "access_token",
}, },
{ {
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call", name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest}, statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`, wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`,
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "error when the server returns 400 Bad Request with bad JSON body on the first call", name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest}, statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`invalid JSON body`}, returnErrBodies: []string{`invalid JSON body`},
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`, wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "error when the server returns 400 Bad Request with empty body", name: "error when the server returns 400 Bad Request with empty body",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest}, statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{``}, returnErrBodies: []string{``},
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`, wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call", name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusForbidden}, statusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: "server responded with status 403", wantErr: "server responded with status 403",
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "error when server returns any other 400 error on first call", name: "error when server returns any other 400 error on first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest}, statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`, wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "error when server returns any other error aside from 400 on first call", name: "error when server returns any other error aside from 400 on first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusForbidden}, statusCodes: []int{http.StatusForbidden},
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: "server responded with status 403", wantErr: "server responded with status 403",
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -536,15 +578,15 @@ func TestProviderConfig(t *testing.T) {
if numRequests == 1 { if numRequests == 1 {
// First request should use client_id/client_secret params. // First request should use client_id/client_secret params.
require.Equal(t, 4, len(r.Form)) require.Equal(t, 4, len(r.Form))
require.Equal(t, "test-upstream-token", r.Form.Get("token"))
require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
require.Equal(t, "test-client-id", r.Form.Get("client_id")) require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-client-secret", r.Form.Get("client_secret")) require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint"))
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token"))
} else { } else {
// Second request, if there is one, should use basic auth. // Second request, if there is one, should use basic auth.
require.Equal(t, 2, len(r.Form)) require.Equal(t, 2, len(r.Form))
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint")) require.Equal(t, "test-upstream-token", r.Form.Get("token"))
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token")) require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
username, password, hasBasicAuth := r.BasicAuth() username, password, hasBasicAuth := r.BasicAuth()
require.True(t, hasBasicAuth, "request should have had basic auth but did not") require.True(t, hasBasicAuth, "request should have had basic auth but did not")
require.Equal(t, "test-client-id", username) require.Equal(t, "test-client-id", username)
@ -574,9 +616,10 @@ func TestProviderConfig(t *testing.T) {
p.RevocationURL = nil p.RevocationURL = nil
} }
err = p.RevokeRefreshToken( err = p.RevokeToken(
context.Background(), context.Background(),
"test-initial-refresh-token", "test-upstream-token",
tt.tokenType,
) )
require.Equal(t, tt.wantNumRequests, numRequests, require.Equal(t, tt.wantNumRequests, numRequests,