Perform an upstream refresh during downstream refresh for OIDC upstreams

- If the upstream refresh fails, then fail the downstream refresh
- If the upstream refresh returns an ID token, then validate it (we
  use its claims in the future, but not in this commit)
- If the upstream refresh returns a new refresh token, then save it
  into the user's session in storage
- Pass the provider cache into the token handler so it can use the
  cached providers to perform upstream refreshes
- Handle unexpected errors in the token handler where the user's session
  does not contain the expected data. These should not be possible
  in practice unless someone is manually editing the storage, but
  handle them anyway just to be safe.
- Refactor to share the refresh code between the CLI and the token
  endpoint by moving it into the UpstreamOIDCIdentityProviderI
  interface, since the token endpoint needed it to be part of that
  interface anyway
This commit is contained in:
Ryan Richard 2021-10-13 12:31:20 -07:00
parent 1bd346cbeb
commit 79ca1d7fb0
11 changed files with 1195 additions and 126 deletions

View File

@ -200,6 +200,21 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PasswordCredentialsGran
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCredentialsGrantAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PasswordCredentialsGrantAndValidateTokens), arg0, arg1, arg2)
}
// PerformRefresh mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) PerformRefresh(arg0 context.Context, arg1 string) (*oauth2.Token, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PerformRefresh", arg0, arg1)
ret0, _ := ret[0].(*oauth2.Token)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PerformRefresh indicates an expected call of PerformRefresh.
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
}
// ValidateToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) {
m.ctrl.T.Helper()

View File

@ -210,8 +210,6 @@ func FositeOauth2Helper(
// The default is to support all prompt values from the spec.
// See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// We'll make a best effort to support these by passing the value of this prompt param to the upstream IDP
// and rely on its implementation of this param.
AllowedPromptValues: nil,
// Use the fosite default to make it more likely that off the shelf OIDC clients can work with the supervisor.
@ -232,7 +230,7 @@ func FositeOauth2Helper(
compose.OpenIDConnectExplicitFactory,
compose.OpenIDConnectRefreshFactory,
compose.OAuth2PKCEFactory,
TokenExchangeFactory,
TokenExchangeFactory, // handle the "urn:ietf:params:oauth:grant-type:token-exchange" grant type
)
provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template()
return provider

View File

@ -63,6 +63,14 @@ type UpstreamOIDCIdentityProviderI interface {
redirectURI string,
) (*oidctypes.Token, error)
// PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not
// return a new ID or refresh token in the response. If it returns an ID token, then use ValidateRefresh to
// validate the ID token.
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
// tokens, or an error.
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
}

View File

@ -130,6 +130,7 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
)
m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler(
m.upstreamIDPs,
oauthHelperWithKubeStorage,
)

View File

@ -5,17 +5,36 @@
package token
import (
"context"
"net/http"
"github.com/ory/fosite"
"github.com/ory/x/errorsx"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession"
)
var (
errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "There was an internal server error.",
HintField: "Required upstream data not found in session.",
CodeField: http.StatusInternalServerError,
}
errUpstreamRefreshError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "Error during upstream refresh.",
CodeField: http.StatusUnauthorized,
}
)
func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
@ -27,6 +46,20 @@ func NewHandler(
return nil
}
// Check if we are performing a refresh grant.
if accessRequest.GetGrantTypes().ExactOne("refresh_token") {
// The above call to NewAccessRequest has loaded the session from storage into the accessRequest variable.
// The session, requested scopes, and requested audience from the original authorize request was retrieved
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
// have already been granted on the accessRequest.
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
if err != nil {
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err)
return nil
}
}
accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest)
if err != nil {
plog.Info("token response error", oidc.FositeErrorForLog(err)...)
@ -39,3 +72,97 @@ func NewHandler(
return nil
})
}
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom
if customSessionData == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
providerName := customSessionData.ProviderName
providerUID := customSessionData.ProviderUID
if providerUID == "" || providerName == "" {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, customSessionData, providerCache)
case psession.ProviderTypeLDAP:
// upstream refresh not yet implemented for LDAP, so do nothing
case psession.ProviderTypeActiveDirectory:
// upstream refresh not yet implemented for AD, so do nothing
default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
return nil
}
func upstreamOIDCRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error {
if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
}
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache)
if err != nil {
return err
}
plog.Debug("attempting upstream refresh request",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed using provider %q of type %q.",
s.ProviderName, s.ProviderType).WithWrap(err))
}
// Upstream refresh may or may not return a new ID token. From the spec:
// "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token."
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
_, hasIDTok := refreshedTokens.Extra("id_token").(string)
if hasIDTok {
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at
// least some providers do not include one, so we skip the nonce validation here (but not other validations).
_, err = p.ValidateToken(ctx, refreshedTokens, "")
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh returned an invalid ID token using provider %q of type %q.",
s.ProviderName, s.ProviderType).WithWrap(err))
}
} else {
plog.Debug("upstream refresh request did not return a new ID token",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
}
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
// overwriting the old one.
if refreshedTokens.RefreshToken != "" {
plog.Debug("upstream refresh request did not return a new refresh token",
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken
}
return nil
}
func findOIDCProviderByNameAndValidateUID(
s *psession.CustomSessionData,
providerCache oidc.UpstreamIdentityProvidersLister,
) (provider.UpstreamOIDCIdentityProviderI, error) {
for _, p := range providerCache.GetOIDCIdentityProviders() {
if p.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID {
return nil, errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Provider %q of type %q from upstream session data has changed its resource UID since authentication.",
s.ProviderName, s.ProviderType))
}
return p, nil
}
}
return nil, errorsx.WithStack(errUpstreamRefreshError.
WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType))
}

File diff suppressed because it is too large Load Diff

View File

@ -58,6 +58,21 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
Password string
}
// PerformRefreshArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
type PerformRefreshArgs struct {
Ctx context.Context
RefreshToken string
}
// ValidateTokenArgs is used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc().
type ValidateTokenArgs struct {
Ctx context.Context
Tok *oauth2.Token
ExpectedIDTokenNonce nonce.Nonce
}
type TestUpstreamLDAPIdentityProvider struct {
Name string
ResourceUID types.UID
@ -107,10 +122,18 @@ type TestUpstreamOIDCIdentityProvider struct {
password string,
) (*oidctypes.Token, error)
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
exchangeAuthcodeAndValidateTokensCallCount int
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
passwordCredentialsGrantAndValidateTokensCallCount int
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
performRefreshCallCount int
performRefreshArgs []*PerformRefreshArgs
validateTokenCallCount int
validateTokenArgs []*ValidateTokenArgs
}
var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{}
@ -193,8 +216,51 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs
return u.exchangeAuthcodeAndValidateTokensArgs[call]
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) {
panic("implement me")
func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
}
u.performRefreshCallCount++
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx,
RefreshToken: refreshToken,
})
return u.PerformRefreshFunc(ctx, refreshToken)
}
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
return u.performRefreshCallCount
}
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs {
if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
}
return u.performRefreshArgs[call]
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
if u.validateTokenArgs == nil {
u.validateTokenArgs = make([]*ValidateTokenArgs, 0)
}
u.validateTokenCallCount++
u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{
Ctx: ctx,
Tok: tok,
ExpectedIDTokenNonce: expectedIDTokenNonce,
})
return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce)
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int {
return u.validateTokenCallCount
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs {
if u.validateTokenArgs == nil {
u.validateTokenArgs = make([]*ValidateTokenArgs, 0)
}
return u.validateTokenArgs[call]
}
type UpstreamIDPListerBuilder struct {
@ -316,6 +382,80 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
t *testing.T,
expectedPerformedByUpstreamName string,
expectedArgs *PerformRefreshArgs,
) {
t.Helper()
var actualArgs *PerformRefreshArgs
var actualNameOfUpstreamWhichMadeCall string
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.performRefreshArgs[0]
}
}
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
"should have been exactly one call to PerformRefresh() by all OIDC upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"PerformRefresh() was called on the wrong OIDC upstream",
)
require.Equal(t, expectedArgs, actualArgs)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
t.Helper()
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.performRefreshCallCount
}
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to PerformRefresh()",
)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken(
t *testing.T,
expectedPerformedByUpstreamName string,
expectedArgs *ValidateTokenArgs,
) {
t.Helper()
var actualArgs *ValidateTokenArgs
var actualNameOfUpstreamWhichMadeCall string
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
if callCountOnThisUpstream == 1 {
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
actualArgs = upstreamOIDC.validateTokenArgs[0]
}
}
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
"should have been exactly one call to ValidateToken() by all OIDC upstreams",
)
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
"ValidateToken() was called on the wrong OIDC upstream",
)
require.Equal(t, expectedArgs, actualArgs)
}
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *testing.T) {
t.Helper()
actualCallCountAcrossAllOIDCUpstreams := 0
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount
}
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
"expected exactly zero calls to ValidateToken()",
)
}
func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder {
return &UpstreamIDPListerBuilder{}
}
@ -329,11 +469,15 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
refreshToken *oidctypes.RefreshToken
usernameClaim string
groupsClaim string
refreshedTokens *oauth2.Token
validatedTokens *oidctypes.Token
authorizationURL url.URL
additionalAuthcodeParams map[string]string
allowPasswordGrant bool
authcodeExchangeErr error
passwordGrantErr error
performRefreshErr error
validateTokenErr error
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder {
@ -429,6 +573,26 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err err
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshedTokens(tokens *oauth2.Token) *TestUpstreamOIDCIdentityProviderBuilder {
u.refreshedTokens = tokens
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.performRefreshErr = err
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder {
u.validatedTokens = tokens
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
u.validateTokenErr = err
return u
}
func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider {
return &TestUpstreamOIDCIdentityProvider{
Name: u.name,
@ -452,6 +616,18 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
}
return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken}, nil
},
PerformRefreshFunc: func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
if u.performRefreshErr != nil {
return nil, u.performRefreshErr
}
return u.refreshedTokens, nil
},
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
if u.validateTokenErr != nil {
return nil, u.validateTokenErr
}
return u.validatedTokens, nil
},
}
}

View File

@ -47,6 +47,8 @@ type ProviderConfig struct {
}
}
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
func (p *ProviderConfig) GetResourceUID() types.UID {
return p.ResourceUID
}
@ -120,6 +122,14 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context,
return p.ValidateToken(ctx, tok, expectedIDTokenNonce)
}
func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
// Create a TokenSource without an access token, so it thinks that a refresh is immediately required.
// Then ask it for the tokens to cause it to perform the refresh and return the results.
return p.Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}).Token()
}
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,
// if the provider offers the userinfo endpoint.
func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
@ -146,7 +156,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
}
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil {
if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims); err != nil {
return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err)
}
@ -167,7 +177,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
}, nil
}
func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error {
func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error {
idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string)
if len(idTokenSubject) == 0 {
return nil // defer to existing ID token validation

View File

@ -23,6 +23,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
)
@ -288,6 +289,171 @@ func TestProviderConfig(t *testing.T) {
}
})
t.Run("PerformRefresh", func(t *testing.T) {
tests := []struct {
name string
returnIDTok string
returnAccessTok string
returnRefreshTok string
returnTokType string
returnExpiresIn string
tokenStatusCode int
wantErr string
wantToken *oauth2.Token
wantTokenExtras map[string]interface{}
}{
{
name: "success when the server returns all tokens in the refresh result",
returnIDTok: "test-id-token",
returnAccessTok: "test-access-token",
returnRefreshTok: "test-refresh-token",
returnTokType: "test-token-type",
returnExpiresIn: "42",
tokenStatusCode: http.StatusOK,
wantToken: &oauth2.Token{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
TokenType: "test-token-type",
Expiry: time.Now().Add(42 * time.Second),
},
wantTokenExtras: map[string]interface{}{
// the ID token only appears in the extras map
"id_token": "test-id-token",
// the library also repeats all the other keys/values returned by the server in the raw extras map
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"token_type": "test-token-type",
"expires_in": "42",
// the library also adds this zero-value even though the server did not return it
"expiry": "0001-01-01T00:00:00Z",
},
},
{
name: "success when the server does not return a new refresh token in the refresh result",
returnIDTok: "test-id-token",
returnAccessTok: "test-access-token",
returnRefreshTok: "",
returnTokType: "test-token-type",
returnExpiresIn: "42",
tokenStatusCode: http.StatusOK,
wantToken: &oauth2.Token{
AccessToken: "test-access-token",
// the library sets the original refresh token into the result, even though the server did not return that
RefreshToken: "test-initial-refresh-token",
TokenType: "test-token-type",
Expiry: time.Now().Add(42 * time.Second),
},
wantTokenExtras: map[string]interface{}{
// the ID token only appears in the extras map
"id_token": "test-id-token",
// the library also repeats all the other keys/values returned by the server in the raw extras map
"access_token": "test-access-token",
"token_type": "test-token-type",
"expires_in": "42",
// the library also adds this zero-value even though the server did not return it
"expiry": "0001-01-01T00:00:00Z",
},
},
{
name: "success when the server does not return a new ID token in the refresh result",
returnIDTok: "",
returnAccessTok: "test-access-token",
returnRefreshTok: "test-refresh-token",
returnTokType: "test-token-type",
returnExpiresIn: "42",
tokenStatusCode: http.StatusOK,
wantToken: &oauth2.Token{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
TokenType: "test-token-type",
Expiry: time.Now().Add(42 * time.Second),
},
wantTokenExtras: map[string]interface{}{
// the library also repeats all the other keys/values returned by the server in the raw extras map
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"token_type": "test-token-type",
"expires_in": "42",
// the library also adds this zero-value even though the server did not return it
"expiry": "0001-01-01T00:00:00Z",
},
},
{
name: "server returns an error on token refresh",
tokenStatusCode: http.StatusForbidden,
wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: fake error\n",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, 4, len(r.Form))
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
require.Equal(t, "refresh_token", r.Form.Get("grant_type"))
require.Equal(t, "test-initial-refresh-token", r.Form.Get("refresh_token"))
if tt.tokenStatusCode != http.StatusOK {
http.Error(w, "fake error", tt.tokenStatusCode)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
ExpiresIn string `json:"expires_in,omitempty"`
}
response.IDToken = tt.returnIDTok
response.AccessToken = tt.returnAccessTok
response.RefreshToken = tt.returnRefreshTok
response.TokenType = tt.returnTokType
response.ExpiresIn = tt.returnExpiresIn
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
p := ProviderConfig{
Name: "test-name",
UsernameClaim: "test-username-claim",
GroupsClaim: "test-groups-claim",
Config: &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
Endpoint: oauth2.Endpoint{
AuthURL: "https://example.com",
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
Scopes: []string{"scope1", "scope2"},
},
}
tok, err := p.PerformRefresh(
context.Background(),
"test-initial-refresh-token",
)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantToken.TokenType, tok.TokenType)
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
require.Equal(t, tt.wantToken.AccessToken, tok.AccessToken)
testutil.RequireTimeInDelta(t, tt.wantToken.Expiry, tok.Expiry, 5*time.Second)
for k, v := range tt.wantTokenExtras {
require.Equal(t, v, tok.Extra(k))
}
})
}
})
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
tests := []struct {
name string

View File

@ -808,9 +808,9 @@ func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidcty
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
h.logger.V(debugLogLevel).Info("Pinniped: Refreshing cached token.")
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
upstreamOIDCIdentityProvider := h.getProvider(h.oauth2Config, h.provider, h.httpClient)
refreshed, err := refreshSource.Token()
refreshed, err := upstreamOIDCIdentityProvider.PerformRefresh(ctx, refreshToken.Token)
if err != nil {
// Ignore errors during refresh, but return nil which will trigger the full login flow.
return nil, nil
@ -818,7 +818,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
return upstreamOIDCIdentityProvider.ValidateToken(ctx, refreshed, "")
}
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {

View File

@ -35,6 +35,7 @@ import (
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/testlogger"
"go.pinniped.dev/internal/upstreamoidc"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
@ -404,11 +405,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
mock.EXPECT().
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
// Call the real production code to perform a refresh.
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
})
return mock
}
@ -445,11 +452,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(nil, fmt.Errorf("some validation error"))
mock.EXPECT().
PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token").
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
// Call the real production code to perform a refresh.
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
})
return mock
}
@ -1522,11 +1535,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
})
h.cache = cache
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
mock.EXPECT().
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
// Call the real production code to perform a refresh.
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
})
return mock
}