Support revocation of access tokens in UpstreamOIDCIdentityProviderI

- Rename the RevokeRefreshToken() function to RevokeToken() and make it
  take the token type (refresh or access) as a new parameter.
- This is a prefactor getting ready to support revocation of upstream
  access tokens in the garbage collection handler.
This commit is contained in:
Ryan Richard 2021-12-03 13:44:24 -08:00
parent edd3547977
commit b981055d31
7 changed files with 204 additions and 141 deletions

View File

@ -243,7 +243,7 @@ func (c *garbageCollectorController) revokeUpstreamOIDCRefreshToken(ctx context.
}
// Revoke the upstream refresh token. This is a noop if the upstream provider does not offer a revocation endpoint.
err := foundOIDCIdentityProviderI.RevokeRefreshToken(ctx, customSessionData.OIDC.UpstreamRefreshToken)
err := foundOIDCIdentityProviderI.RevokeToken(ctx, customSessionData.OIDC.UpstreamRefreshToken, provider.RefreshTokenType)
if err != nil {
// This could be a network failure, a 503 result which we should retry
// (see https://datatracker.ietf.org/doc/html/rfc7009#section-2.2.1),

View File

@ -366,18 +366,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is only revoked for the active authcode session.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{
Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token",
&oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
},
)
@ -448,14 +449,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't read the invalid secret.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The invalid authcode session secrets is still deleted because it is expired.
r.ElementsMatch(
@ -524,14 +525,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't find the upstream in the cache.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The authcode session secrets is still deleted because it is expired.
r.ElementsMatch(
@ -600,14 +601,14 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Nothing to revoke since we couldn't find the upstream in the cache.
idpListerBuilder.RequireExactlyZeroCallsToRevokeRefreshToken(t)
idpListerBuilder.RequireExactlyZeroCallsToRevokeToken(t)
// The authcode session secrets is still deleted because it is expired.
r.ElementsMatch(
@ -677,18 +678,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Tried to revoke it, although this revocation will fail.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{
Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token",
&oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
},
)
@ -749,18 +751,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
WithRevokeTokenError(errors.New("some upstream revocation error")) // the upstream revocation will fail
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// Tried to revoke it, although this revocation will fail.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{
Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token",
&oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
},
)
@ -875,18 +878,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is only revoked for the downstream session which had offline_access granted.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{
Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token",
&oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
},
)
@ -958,18 +962,19 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
happyOIDCUpstream := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName("upstream-oidc-provider-name").
WithResourceUID("upstream-oidc-provider-uid").
WithRevokeRefreshTokenError(nil)
WithRevokeTokenError(nil)
idpListerBuilder := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyOIDCUpstream.Build())
startInformersAndController(idpListerBuilder.Build())
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
// The upstream refresh token is revoked.
idpListerBuilder.RequireExactlyOneCallToRevokeRefreshToken(t,
idpListerBuilder.RequireExactlyOneCallToRevokeToken(t,
"upstream-oidc-provider-name",
&oidctestutil.RevokeRefreshTokenArgs{
Ctx: syncContext.Context,
RefreshToken: "fake-upstream-refresh-token",
&oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType,
},
)
@ -1015,7 +1020,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
r.False(syncContext.Queue.(*testQueue).called)
// Run sync again when not enough time has passed since the most recent run, so no delete
// operations should happen even though there is a expired secret now.
// operations should happen even though there is an expired secret now.
fakeClock.Step(29 * time.Second)
r.NoError(controllerlib.TestSync(t, subject, *syncContext))
require.Empty(t, kubeClient.Actions())

View File

@ -14,6 +14,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
provider "go.pinniped.dev/internal/oidc/provider"
nonce "go.pinniped.dev/pkg/oidcclient/nonce"
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
pkce "go.pinniped.dev/pkg/oidcclient/pkce"
@ -215,18 +216,18 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, ar
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
}
// RevokeRefreshToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) RevokeRefreshToken(arg0 context.Context, arg1 string) error {
// RevokeToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 provider.RevocableTokenType) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeRefreshToken", arg0, arg1)
ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RevokeRefreshToken indicates an expected call of RevokeRefreshToken.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeRefreshToken(arg0, arg1 interface{}) *gomock.Call {
// RevokeToken indicates an expected call of RevokeToken.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) RevokeToken(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeRefreshToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeRefreshToken), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevokeToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).RevokeToken), arg0, arg1, arg2)
}
// ValidateToken mocks base method.

View File

@ -17,6 +17,14 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce"
)
type RevocableTokenType string
// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
const (
RefreshTokenType RevocableTokenType = "refresh_token"
AccessTokenType RevocableTokenType = "access_token"
)
type UpstreamOIDCIdentityProviderI interface {
// GetName returns a name for this upstream provider, which will be used as a component of the path for the
// callback endpoint hosted by the Supervisor.
@ -68,8 +76,8 @@ type UpstreamOIDCIdentityProviderI interface {
// validate the ID token.
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint.
RevokeRefreshToken(ctx context.Context, refreshToken string) error
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated

View File

@ -68,11 +68,12 @@ type PerformRefreshArgs struct {
ExpectedSubject string
}
// RevokeRefreshTokenArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.RevokeRefreshTokenArgsFunc().
type RevokeRefreshTokenArgs struct {
Ctx context.Context
RefreshToken string
// RevokeTokenArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.RevokeTokenArgsFunc().
type RevokeTokenArgs struct {
Ctx context.Context
Token string
TokenType provider.RevocableTokenType
}
// ValidateTokenArgs is used to spy on calls to
@ -166,7 +167,7 @@ type TestUpstreamOIDCIdentityProvider struct {
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
RevokeRefreshTokenFunc func(ctx context.Context, refreshToken string) error
RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
@ -176,8 +177,8 @@ type TestUpstreamOIDCIdentityProvider struct {
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
performRefreshCallCount int
performRefreshArgs []*PerformRefreshArgs
revokeRefreshTokenCallCount int
revokeRefreshTokenArgs []*RevokeRefreshTokenArgs
revokeTokenCallCount int
revokeTokenArgs []*RevokeTokenArgs
validateTokenCallCount int
validateTokenArgs []*ValidateTokenArgs
}
@ -278,16 +279,17 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
return u.PerformRefreshFunc(ctx, refreshToken)
}
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
if u.revokeRefreshTokenArgs == nil {
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0)
func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
if u.revokeTokenArgs == nil {
u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
}
u.revokeRefreshTokenCallCount++
u.revokeRefreshTokenArgs = append(u.revokeRefreshTokenArgs, &RevokeRefreshTokenArgs{
Ctx: ctx,
RefreshToken: refreshToken,
u.revokeTokenCallCount++
u.revokeTokenArgs = append(u.revokeTokenArgs, &RevokeTokenArgs{
Ctx: ctx,
Token: token,
TokenType: tokenType,
})
return u.RevokeRefreshTokenFunc(ctx, refreshToken)
return u.RevokeTokenFunc(ctx, token, tokenType)
}
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
@ -301,15 +303,15 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *Perform
return u.performRefreshArgs[call]
}
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenCallCount() int {
func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenCallCount() int {
return u.performRefreshCallCount
}
func (u *TestUpstreamOIDCIdentityProvider) RevokeRefreshTokenArgs(call int) *RevokeRefreshTokenArgs {
if u.revokeRefreshTokenArgs == nil {
u.revokeRefreshTokenArgs = make([]*RevokeRefreshTokenArgs, 0)
func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenArgs(call int) *RevokeTokenArgs {
if u.revokeTokenArgs == nil {
u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
}
return u.revokeRefreshTokenArgs[call]
return u.revokeTokenArgs[call]
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
@ -552,40 +554,40 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *tes
)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeRefreshToken(
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeToken(
t *testing.T,
expectedPerformedByUpstreamName string,
expectedArgs *RevokeRefreshTokenArgs,
expectedArgs *RevokeTokenArgs,
) {
t.Helper()
var actualArgs *RevokeRefreshTokenArgs
var actualArgs *RevokeTokenArgs
var actualNameOfUpstreamWhichMadeCall string
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.revokeRefreshTokenCallCount
callCountOnThisUpstream := upstreamOIDC.revokeTokenCallCount
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.revokeRefreshTokenArgs[0]
actualArgs = upstreamOIDC.revokeTokenArgs[0]
}
}
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
"should have been exactly one call to RevokeRefreshToken() by all OIDC upstreams",
"should have been exactly one call to RevokeToken() by all OIDC upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"RevokeRefreshToken() was called on the wrong OIDC upstream",
"RevokeToken() was called on the wrong OIDC upstream",
)
require.Equal(t, expectedArgs, actualArgs)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeRefreshToken(t *testing.T) {
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeToken(t *testing.T) {
t.Helper()
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeRefreshTokenCallCount
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeTokenCallCount
}
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to RevokeRefreshToken()",
"expected exactly zero calls to RevokeToken()",
)
}
@ -610,7 +612,7 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
authcodeExchangeErr error
passwordGrantErr error
performRefreshErr error
revokeRefreshTokenErr error
revokeTokenErr error
validateTokenErr error
}
@ -727,8 +729,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err err
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeRefreshTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.revokeRefreshTokenErr = err
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.revokeTokenErr = err
return u
}
@ -761,8 +763,8 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
}
return u.refreshedTokens, nil
},
RevokeRefreshTokenFunc: func(ctx context.Context, refreshToken string) error {
return u.revokeRefreshTokenErr
RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error {
return u.revokeTokenErr
},
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
if u.validateTokenErr != nil {

View File

@ -137,32 +137,36 @@ func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string
return p.Config.TokenSource(httpClientContext, &oauth2.Token{RefreshToken: refreshToken}).Token()
}
// RevokeRefreshToken will attempt to revoke the given token, if the provider has a revocation endpoint.
func (p *ProviderConfig) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error {
if p.RevocationURL == nil {
plog.Trace("RevokeRefreshToken() was called but upstream provider has no available revocation endpoint", "providerName", p.Name)
plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint",
"providerName", p.Name,
"tokenType", tokenType,
)
return nil
}
// First try using client auth in the request params.
tryAnotherClientAuthMethod, err := p.tryRevokeRefreshToken(ctx, refreshToken, false)
tryAnotherClientAuthMethod, err := p.tryRevokeToken(ctx, token, tokenType, false)
if tryAnotherClientAuthMethod {
// Try again using basic auth this time. Overwrite the first client auth error,
// which isn't useful anymore when retrying.
_, err = p.tryRevokeRefreshToken(ctx, refreshToken, true)
_, err = p.tryRevokeToken(ctx, token, tokenType, true)
}
return err
}
// tryRevokeRefreshToken will call the revocation endpoint using either basic auth or by including
// tryRevokeToken will call the revocation endpoint using either basic auth or by including
// client auth in the request params. It will return an error when the request failed. If the
// request failed for a reason that might be due to bad client auth, then it will return true
// for the tryAnotherClientAuthMethod return value, indicating that it might be worth trying
// again using the other client auth method.
// RFC 7009 defines how to make a revocation request and how to interpret the response.
// See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 for details.
func (p *ProviderConfig) tryRevokeRefreshToken(
func (p *ProviderConfig) tryRevokeToken(
ctx context.Context,
refreshToken string,
token string,
tokenType provider.RevocableTokenType,
useBasicAuth bool,
) (tryAnotherClientAuthMethod bool, err error) {
clientID := p.Config.ClientID
@ -171,8 +175,8 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
httpClient := p.Client
params := url.Values{
"token": []string{refreshToken},
"token_type_hint": []string{"refresh_token"},
"token": []string{token},
"token_type_hint": []string{string(tokenType)},
}
if !useBasicAuth {
params["client_id"] = []string{clientID}
@ -200,11 +204,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
switch resp.StatusCode {
case http.StatusOK:
// Success!
plog.Trace("RevokeRefreshToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
plog.Trace("RevokeToken() got 200 OK response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
return false, nil
case http.StatusBadRequest:
// Bad request might be due to bad client auth method. Try to detect that.
plog.Trace("RevokeRefreshToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
plog.Trace("RevokeToken() got 400 Bad Request response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
body, err := io.ReadAll(resp.Body)
if err != nil {
return false,
@ -227,11 +231,11 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
}
// Got an "invalid_client" response, which might mean client auth failed, so it may be worth trying again
// using another client auth method. See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
plog.Trace("RevokeRefreshToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
plog.Trace("RevokeToken()'s 400 Bad Request response from provider's revocation endpoint was type 'invalid_client'", "providerName", p.Name, "usedBasicAuth", useBasicAuth)
return true, err
default:
// Any other error is probably not due to failed client auth.
plog.Trace("RevokeRefreshToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode)
plog.Trace("RevokeToken() got unexpected error response from provider's revocation endpoint", "providerName", p.Name, "usedBasicAuth", useBasicAuth, "statusCode", resp.StatusCode)
return false, fmt.Errorf("server responded with status %d", resp.StatusCode)
}
}

View File

@ -24,6 +24,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -455,73 +456,114 @@ func TestProviderConfig(t *testing.T) {
}
})
t.Run("RevokeRefreshToken", func(t *testing.T) {
t.Run("RevokeToken", func(t *testing.T) {
tests := []struct {
name string
nilRevocationURL bool
statusCodes []int
returnErrBodies []string
wantErr string
wantNumRequests int
name string
tokenType provider.RevocableTokenType
nilRevocationURL bool
statusCodes []int
returnErrBodies []string
wantErr string
wantNumRequests int
wantTokenTypeHint string
}{
{
name: "success without calling the server when there is no revocation URL set",
name: "success without calling the server when there is no revocation URL set for refresh token",
tokenType: provider.RefreshTokenType,
nilRevocationURL: true,
wantNumRequests: 0,
},
{
name: "success when the server returns 200 OK on the first call",
statusCodes: []int{http.StatusOK},
wantNumRequests: 1,
name: "success without calling the server when there is no revocation URL set for access token",
tokenType: provider.AccessTokenType,
nilRevocationURL: true,
wantNumRequests: 0,
},
{
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call",
name: "success when the server returns 200 OK on the first call for refresh token",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusOK},
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
{
name: "success when the server returns 200 OK on the first call for access token",
tokenType: provider.AccessTokenType,
statusCodes: []int{http.StatusOK},
wantNumRequests: 1,
wantTokenTypeHint: "access_token",
},
{
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
wantNumRequests: 2,
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`,
wantNumRequests: 2,
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token",
tokenType: provider.AccessTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
wantNumRequests: 2,
wantTokenTypeHint: "access_token",
},
{
name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`invalid JSON body`},
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`,
wantNumRequests: 1,
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`,
wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when the server returns 400 Bad Request with empty body",
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{``},
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`,
wantNumRequests: 1,
name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`invalid JSON body`},
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`,
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
statusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: "server responded with status 403",
wantNumRequests: 2,
name: "error when the server returns 400 Bad Request with empty body",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{``},
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`,
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when server returns any other 400 error on first call",
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`,
wantNumRequests: 1,
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: "server responded with status 403",
wantNumRequests: 2,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when server returns any other error aside from 400 on first call",
statusCodes: []int{http.StatusForbidden},
returnErrBodies: []string{""},
wantErr: "server responded with status 403",
wantNumRequests: 1,
name: "error when server returns any other 400 error on first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`,
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
{
name: "error when server returns any other error aside from 400 on first call",
tokenType: provider.RefreshTokenType,
statusCodes: []int{http.StatusForbidden},
returnErrBodies: []string{""},
wantErr: "server responded with status 403",
wantNumRequests: 1,
wantTokenTypeHint: "refresh_token",
},
}
for _, tt := range tests {
@ -536,15 +578,15 @@ func TestProviderConfig(t *testing.T) {
if numRequests == 1 {
// First request should use client_id/client_secret params.
require.Equal(t, 4, len(r.Form))
require.Equal(t, "test-upstream-token", r.Form.Get("token"))
require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint"))
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token"))
} else {
// Second request, if there is one, should use basic auth.
require.Equal(t, 2, len(r.Form))
require.Equal(t, "refresh_token", r.Form.Get("token_type_hint"))
require.Equal(t, "test-initial-refresh-token", r.Form.Get("token"))
require.Equal(t, "test-upstream-token", r.Form.Get("token"))
require.Equal(t, tt.wantTokenTypeHint, r.Form.Get("token_type_hint"))
username, password, hasBasicAuth := r.BasicAuth()
require.True(t, hasBasicAuth, "request should have had basic auth but did not")
require.Equal(t, "test-client-id", username)
@ -574,9 +616,10 @@ func TestProviderConfig(t *testing.T) {
p.RevocationURL = nil
}
err = p.RevokeRefreshToken(
err = p.RevokeToken(
context.Background(),
"test-initial-refresh-token",
"test-upstream-token",
tt.tokenType,
)
require.Equal(t, tt.wantNumRequests, numRequests,