Perform an upstream refresh during downstream refresh for OIDC upstreams
- If the upstream refresh fails, then fail the downstream refresh - If the upstream refresh returns an ID token, then validate it (we use its claims in the future, but not in this commit) - If the upstream refresh returns a new refresh token, then save it into the user's session in storage - Pass the provider cache into the token handler so it can use the cached providers to perform upstream refreshes - Handle unexpected errors in the token handler where the user's session does not contain the expected data. These should not be possible in practice unless someone is manually editing the storage, but handle them anyway just to be safe. - Refactor to share the refresh code between the CLI and the token endpoint by moving it into the UpstreamOIDCIdentityProviderI interface, since the token endpoint needed it to be part of that interface anyway
This commit is contained in:
parent
1bd346cbeb
commit
79ca1d7fb0
@ -200,6 +200,21 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PasswordCredentialsGran
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCredentialsGrantAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PasswordCredentialsGrantAndValidateTokens), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// PerformRefresh mocks base method.
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) PerformRefresh(arg0 context.Context, arg1 string) (*oauth2.Token, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PerformRefresh", arg0, arg1)
|
||||
ret0, _ := ret[0].(*oauth2.Token)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PerformRefresh indicates an expected call of PerformRefresh.
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PerformRefresh", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).PerformRefresh), arg0, arg1)
|
||||
}
|
||||
|
||||
// ValidateToken mocks base method.
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -210,8 +210,6 @@ func FositeOauth2Helper(
|
||||
|
||||
// The default is to support all prompt values from the spec.
|
||||
// See https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
|
||||
// We'll make a best effort to support these by passing the value of this prompt param to the upstream IDP
|
||||
// and rely on its implementation of this param.
|
||||
AllowedPromptValues: nil,
|
||||
|
||||
// Use the fosite default to make it more likely that off the shelf OIDC clients can work with the supervisor.
|
||||
@ -232,7 +230,7 @@ func FositeOauth2Helper(
|
||||
compose.OpenIDConnectExplicitFactory,
|
||||
compose.OpenIDConnectRefreshFactory,
|
||||
compose.OAuth2PKCEFactory,
|
||||
TokenExchangeFactory,
|
||||
TokenExchangeFactory, // handle the "urn:ietf:params:oauth:grant-type:token-exchange" grant type
|
||||
)
|
||||
provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template()
|
||||
return provider
|
||||
|
@ -63,6 +63,14 @@ type UpstreamOIDCIdentityProviderI interface {
|
||||
redirectURI string,
|
||||
) (*oidctypes.Token, error)
|
||||
|
||||
// PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not
|
||||
// return a new ID or refresh token in the response. If it returns an ID token, then use ValidateRefresh to
|
||||
// validate the ID token.
|
||||
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
||||
|
||||
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response
|
||||
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
|
||||
// tokens, or an error.
|
||||
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
|
||||
}
|
||||
|
||||
|
@ -130,6 +130,7 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
|
||||
)
|
||||
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler(
|
||||
m.upstreamIDPs,
|
||||
oauthHelperWithKubeStorage,
|
||||
)
|
||||
|
||||
|
@ -5,17 +5,36 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/x/errorsx"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/psession"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "There was an internal server error.",
|
||||
HintField: "Required upstream data not found in session.",
|
||||
CodeField: http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
errUpstreamRefreshError = &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "Error during upstream refresh.",
|
||||
CodeField: http.StatusUnauthorized,
|
||||
}
|
||||
)
|
||||
|
||||
func NewHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
@ -27,6 +46,20 @@ func NewHandler(
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we are performing a refresh grant.
|
||||
if accessRequest.GetGrantTypes().ExactOne("refresh_token") {
|
||||
// The above call to NewAccessRequest has loaded the session from storage into the accessRequest variable.
|
||||
// The session, requested scopes, and requested audience from the original authorize request was retrieved
|
||||
// from the Kube storage layer and added to the accessRequest. Additionally, the audience and scopes may
|
||||
// have already been granted on the accessRequest.
|
||||
err = upstreamRefresh(r.Context(), accessRequest, idpLister)
|
||||
if err != nil {
|
||||
plog.Info("upstream refresh error", oidc.FositeErrorForLog(err)...)
|
||||
oauthHelper.WriteAccessError(w, accessRequest, err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest)
|
||||
if err != nil {
|
||||
plog.Info("token response error", oidc.FositeErrorForLog(err)...)
|
||||
@ -39,3 +72,97 @@ func NewHandler(
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
||||
customSessionData := session.Custom
|
||||
if customSessionData == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
providerName := customSessionData.ProviderName
|
||||
providerUID := customSessionData.ProviderUID
|
||||
if providerUID == "" || providerName == "" {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
|
||||
switch customSessionData.ProviderType {
|
||||
case psession.ProviderTypeOIDC:
|
||||
return upstreamOIDCRefresh(ctx, customSessionData, providerCache)
|
||||
case psession.ProviderTypeLDAP:
|
||||
// upstream refresh not yet implemented for LDAP, so do nothing
|
||||
case psession.ProviderTypeActiveDirectory:
|
||||
// upstream refresh not yet implemented for AD, so do nothing
|
||||
default:
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func upstreamOIDCRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
if s.OIDC == nil || s.OIDC.UpstreamRefreshToken == "" {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
|
||||
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
plog.Debug("attempting upstream refresh request",
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
|
||||
refreshedTokens, err := p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed using provider %q of type %q.",
|
||||
s.ProviderName, s.ProviderType).WithWrap(err))
|
||||
}
|
||||
|
||||
// Upstream refresh may or may not return a new ID token. From the spec:
|
||||
// "the response body is the Token Response of Section 3.1.3.3 except that it might not contain an id_token."
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse
|
||||
_, hasIDTok := refreshedTokens.Extra("id_token").(string)
|
||||
if hasIDTok {
|
||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at
|
||||
// least some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
_, err = p.ValidateToken(ctx, refreshedTokens, "")
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh returned an invalid ID token using provider %q of type %q.",
|
||||
s.ProviderName, s.ProviderType).WithWrap(err))
|
||||
}
|
||||
} else {
|
||||
plog.Debug("upstream refresh request did not return a new ID token",
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
}
|
||||
|
||||
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
|
||||
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
|
||||
// overwriting the old one.
|
||||
if refreshedTokens.RefreshToken != "" {
|
||||
plog.Debug("upstream refresh request did not return a new refresh token",
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
s.OIDC.UpstreamRefreshToken = refreshedTokens.RefreshToken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findOIDCProviderByNameAndValidateUID(
|
||||
s *psession.CustomSessionData,
|
||||
providerCache oidc.UpstreamIdentityProvidersLister,
|
||||
) (provider.UpstreamOIDCIdentityProviderI, error) {
|
||||
for _, p := range providerCache.GetOIDCIdentityProviders() {
|
||||
if p.GetName() == s.ProviderName {
|
||||
if p.GetResourceUID() != s.ProviderUID {
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Provider %q of type %q from upstream session data has changed its resource UID since authentication.",
|
||||
s.ProviderName, s.ProviderType))
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError.
|
||||
WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -58,6 +58,21 @@ type PasswordCredentialsGrantAndValidateTokensArgs struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
// PerformRefreshArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.PerformRefreshFunc().
|
||||
type PerformRefreshArgs struct {
|
||||
Ctx context.Context
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// ValidateTokenArgs is used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.ValidateTokenFunc().
|
||||
type ValidateTokenArgs struct {
|
||||
Ctx context.Context
|
||||
Tok *oauth2.Token
|
||||
ExpectedIDTokenNonce nonce.Nonce
|
||||
}
|
||||
|
||||
type TestUpstreamLDAPIdentityProvider struct {
|
||||
Name string
|
||||
ResourceUID types.UID
|
||||
@ -107,10 +122,18 @@ type TestUpstreamOIDCIdentityProvider struct {
|
||||
password string,
|
||||
) (*oidctypes.Token, error)
|
||||
|
||||
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
|
||||
|
||||
ValidateTokenFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
|
||||
|
||||
exchangeAuthcodeAndValidateTokensCallCount int
|
||||
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
|
||||
passwordCredentialsGrantAndValidateTokensCallCount int
|
||||
passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs
|
||||
performRefreshCallCount int
|
||||
performRefreshArgs []*PerformRefreshArgs
|
||||
validateTokenCallCount int
|
||||
validateTokenArgs []*ValidateTokenArgs
|
||||
}
|
||||
|
||||
var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{}
|
||||
@ -193,8 +216,51 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs
|
||||
return u.exchangeAuthcodeAndValidateTokensArgs[call]
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) {
|
||||
panic("implement me")
|
||||
func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
u.performRefreshCallCount++
|
||||
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
|
||||
Ctx: ctx,
|
||||
RefreshToken: refreshToken,
|
||||
})
|
||||
return u.PerformRefreshFunc(ctx, refreshToken)
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int {
|
||||
return u.performRefreshCallCount
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs {
|
||||
if u.performRefreshArgs == nil {
|
||||
u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
|
||||
}
|
||||
return u.performRefreshArgs[call]
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
||||
if u.validateTokenArgs == nil {
|
||||
u.validateTokenArgs = make([]*ValidateTokenArgs, 0)
|
||||
}
|
||||
u.validateTokenCallCount++
|
||||
u.validateTokenArgs = append(u.validateTokenArgs, &ValidateTokenArgs{
|
||||
Ctx: ctx,
|
||||
Tok: tok,
|
||||
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||
})
|
||||
return u.ValidateTokenFunc(ctx, tok, expectedIDTokenNonce)
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenCallCount() int {
|
||||
return u.validateTokenCallCount
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenArgs(call int) *ValidateTokenArgs {
|
||||
if u.validateTokenArgs == nil {
|
||||
u.validateTokenArgs = make([]*ValidateTokenArgs, 0)
|
||||
}
|
||||
return u.validateTokenArgs[call]
|
||||
}
|
||||
|
||||
type UpstreamIDPListerBuilder struct {
|
||||
@ -316,6 +382,80 @@ func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndV
|
||||
)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh(
|
||||
t *testing.T,
|
||||
expectedPerformedByUpstreamName string,
|
||||
expectedArgs *PerformRefreshArgs,
|
||||
) {
|
||||
t.Helper()
|
||||
var actualArgs *PerformRefreshArgs
|
||||
var actualNameOfUpstreamWhichMadeCall string
|
||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount
|
||||
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
||||
actualArgs = upstreamOIDC.performRefreshArgs[0]
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
|
||||
"should have been exactly one call to PerformRefresh() by all OIDC upstreams",
|
||||
)
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"PerformRefresh() was called on the wrong OIDC upstream",
|
||||
)
|
||||
require.Equal(t, expectedArgs, actualArgs)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) {
|
||||
t.Helper()
|
||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.performRefreshCallCount
|
||||
}
|
||||
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
|
||||
"expected exactly zero calls to PerformRefresh()",
|
||||
)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken(
|
||||
t *testing.T,
|
||||
expectedPerformedByUpstreamName string,
|
||||
expectedArgs *ValidateTokenArgs,
|
||||
) {
|
||||
t.Helper()
|
||||
var actualArgs *ValidateTokenArgs
|
||||
var actualNameOfUpstreamWhichMadeCall string
|
||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
callCountOnThisUpstream := upstreamOIDC.validateTokenCallCount
|
||||
actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream
|
||||
if callCountOnThisUpstream == 1 {
|
||||
actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name
|
||||
actualArgs = upstreamOIDC.validateTokenArgs[0]
|
||||
}
|
||||
}
|
||||
require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams,
|
||||
"should have been exactly one call to ValidateToken() by all OIDC upstreams",
|
||||
)
|
||||
require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall,
|
||||
"ValidateToken() was called on the wrong OIDC upstream",
|
||||
)
|
||||
require.Equal(t, expectedArgs, actualArgs)
|
||||
}
|
||||
|
||||
func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *testing.T) {
|
||||
t.Helper()
|
||||
actualCallCountAcrossAllOIDCUpstreams := 0
|
||||
for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders {
|
||||
actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenCallCount
|
||||
}
|
||||
require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams,
|
||||
"expected exactly zero calls to ValidateToken()",
|
||||
)
|
||||
}
|
||||
|
||||
func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder {
|
||||
return &UpstreamIDPListerBuilder{}
|
||||
}
|
||||
@ -329,11 +469,15 @@ type TestUpstreamOIDCIdentityProviderBuilder struct {
|
||||
refreshToken *oidctypes.RefreshToken
|
||||
usernameClaim string
|
||||
groupsClaim string
|
||||
refreshedTokens *oauth2.Token
|
||||
validatedTokens *oidctypes.Token
|
||||
authorizationURL url.URL
|
||||
additionalAuthcodeParams map[string]string
|
||||
allowPasswordGrant bool
|
||||
authcodeExchangeErr error
|
||||
passwordGrantErr error
|
||||
performRefreshErr error
|
||||
validateTokenErr error
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||
@ -429,6 +573,26 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err err
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshedTokens(tokens *oauth2.Token) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||
u.refreshedTokens = tokens
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||
u.performRefreshErr = err
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||
u.validatedTokens = tokens
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder {
|
||||
u.validateTokenErr = err
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider {
|
||||
return &TestUpstreamOIDCIdentityProvider{
|
||||
Name: u.name,
|
||||
@ -452,6 +616,18 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
|
||||
}
|
||||
return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken}, nil
|
||||
},
|
||||
PerformRefreshFunc: func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
if u.performRefreshErr != nil {
|
||||
return nil, u.performRefreshErr
|
||||
}
|
||||
return u.refreshedTokens, nil
|
||||
},
|
||||
ValidateTokenFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
||||
if u.validateTokenErr != nil {
|
||||
return nil, u.validateTokenErr
|
||||
}
|
||||
return u.validatedTokens, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,6 +47,8 @@ type ProviderConfig struct {
|
||||
}
|
||||
}
|
||||
|
||||
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
|
||||
|
||||
func (p *ProviderConfig) GetResourceUID() types.UID {
|
||||
return p.ResourceUID
|
||||
}
|
||||
@ -120,6 +122,14 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context,
|
||||
return p.ValidateToken(ctx, tok, expectedIDTokenNonce)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
// Create a TokenSource without an access token, so it thinks that a refresh is immediately required.
|
||||
// Then ask it for the tokens to cause it to perform the refresh and return the results.
|
||||
return p.Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}).Token()
|
||||
}
|
||||
|
||||
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,
|
||||
// if the provider offers the userinfo endpoint.
|
||||
func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
@ -146,7 +156,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
|
||||
}
|
||||
maybeLogClaims("claims from ID token", p.Name, validatedClaims)
|
||||
|
||||
if err := p.fetchUserInfo(ctx, tok, validatedClaims); err != nil {
|
||||
if err := p.maybeFetchUserInfoAndMergeClaims(ctx, tok, validatedClaims); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusInternalServerError, "could not fetch user info claims", err)
|
||||
}
|
||||
|
||||
@ -167,7 +177,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error {
|
||||
func (p *ProviderConfig) maybeFetchUserInfoAndMergeClaims(ctx context.Context, tok *oauth2.Token, claims map[string]interface{}) error {
|
||||
idTokenSubject, _ := claims[oidc.IDTokenSubjectClaim].(string)
|
||||
if len(idTokenSubject) == 0 {
|
||||
return nil // defer to existing ID token validation
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
@ -288,6 +289,171 @@ func TestProviderConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PerformRefresh", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
returnIDTok string
|
||||
returnAccessTok string
|
||||
returnRefreshTok string
|
||||
returnTokType string
|
||||
returnExpiresIn string
|
||||
tokenStatusCode int
|
||||
|
||||
wantErr string
|
||||
wantToken *oauth2.Token
|
||||
wantTokenExtras map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "success when the server returns all tokens in the refresh result",
|
||||
returnIDTok: "test-id-token",
|
||||
returnAccessTok: "test-access-token",
|
||||
returnRefreshTok: "test-refresh-token",
|
||||
returnTokType: "test-token-type",
|
||||
returnExpiresIn: "42",
|
||||
tokenStatusCode: http.StatusOK,
|
||||
wantToken: &oauth2.Token{
|
||||
AccessToken: "test-access-token",
|
||||
RefreshToken: "test-refresh-token",
|
||||
TokenType: "test-token-type",
|
||||
Expiry: time.Now().Add(42 * time.Second),
|
||||
},
|
||||
wantTokenExtras: map[string]interface{}{
|
||||
// the ID token only appears in the extras map
|
||||
"id_token": "test-id-token",
|
||||
// the library also repeats all the other keys/values returned by the server in the raw extras map
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"token_type": "test-token-type",
|
||||
"expires_in": "42",
|
||||
// the library also adds this zero-value even though the server did not return it
|
||||
"expiry": "0001-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success when the server does not return a new refresh token in the refresh result",
|
||||
returnIDTok: "test-id-token",
|
||||
returnAccessTok: "test-access-token",
|
||||
returnRefreshTok: "",
|
||||
returnTokType: "test-token-type",
|
||||
returnExpiresIn: "42",
|
||||
tokenStatusCode: http.StatusOK,
|
||||
wantToken: &oauth2.Token{
|
||||
AccessToken: "test-access-token",
|
||||
// the library sets the original refresh token into the result, even though the server did not return that
|
||||
RefreshToken: "test-initial-refresh-token",
|
||||
TokenType: "test-token-type",
|
||||
Expiry: time.Now().Add(42 * time.Second),
|
||||
},
|
||||
wantTokenExtras: map[string]interface{}{
|
||||
// the ID token only appears in the extras map
|
||||
"id_token": "test-id-token",
|
||||
// the library also repeats all the other keys/values returned by the server in the raw extras map
|
||||
"access_token": "test-access-token",
|
||||
"token_type": "test-token-type",
|
||||
"expires_in": "42",
|
||||
// the library also adds this zero-value even though the server did not return it
|
||||
"expiry": "0001-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success when the server does not return a new ID token in the refresh result",
|
||||
returnIDTok: "",
|
||||
returnAccessTok: "test-access-token",
|
||||
returnRefreshTok: "test-refresh-token",
|
||||
returnTokType: "test-token-type",
|
||||
returnExpiresIn: "42",
|
||||
tokenStatusCode: http.StatusOK,
|
||||
wantToken: &oauth2.Token{
|
||||
AccessToken: "test-access-token",
|
||||
RefreshToken: "test-refresh-token",
|
||||
TokenType: "test-token-type",
|
||||
Expiry: time.Now().Add(42 * time.Second),
|
||||
},
|
||||
wantTokenExtras: map[string]interface{}{
|
||||
// the library also repeats all the other keys/values returned by the server in the raw extras map
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"token_type": "test-token-type",
|
||||
"expires_in": "42",
|
||||
// the library also adds this zero-value even though the server did not return it
|
||||
"expiry": "0001-01-01T00:00:00Z",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "server returns an error on token refresh",
|
||||
tokenStatusCode: http.StatusForbidden,
|
||||
wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: fake error\n",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Equal(t, 4, len(r.Form))
|
||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
||||
require.Equal(t, "test-client-secret", r.Form.Get("client_secret"))
|
||||
require.Equal(t, "refresh_token", r.Form.Get("grant_type"))
|
||||
require.Equal(t, "test-initial-refresh-token", r.Form.Get("refresh_token"))
|
||||
if tt.tokenStatusCode != http.StatusOK {
|
||||
http.Error(w, "fake error", tt.tokenStatusCode)
|
||||
return
|
||||
}
|
||||
var response struct {
|
||||
oauth2.Token
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn string `json:"expires_in,omitempty"`
|
||||
}
|
||||
response.IDToken = tt.returnIDTok
|
||||
response.AccessToken = tt.returnAccessTok
|
||||
response.RefreshToken = tt.returnRefreshTok
|
||||
response.TokenType = tt.returnTokType
|
||||
response.ExpiresIn = tt.returnExpiresIn
|
||||
w.Header().Set("content-type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||
}))
|
||||
t.Cleanup(tokenServer.Close)
|
||||
|
||||
p := ProviderConfig{
|
||||
Name: "test-name",
|
||||
UsernameClaim: "test-username-claim",
|
||||
GroupsClaim: "test-groups-claim",
|
||||
Config: &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://example.com",
|
||||
TokenURL: tokenServer.URL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
},
|
||||
}
|
||||
|
||||
tok, err := p.PerformRefresh(
|
||||
context.Background(),
|
||||
"test-initial-refresh-token",
|
||||
)
|
||||
|
||||
if tt.wantErr != "" {
|
||||
require.EqualError(t, err, tt.wantErr)
|
||||
require.Nil(t, tok)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantToken.TokenType, tok.TokenType)
|
||||
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
|
||||
require.Equal(t, tt.wantToken.AccessToken, tok.AccessToken)
|
||||
testutil.RequireTimeInDelta(t, tt.wantToken.Expiry, tok.Expiry, 5*time.Second)
|
||||
for k, v := range tt.wantTokenExtras {
|
||||
require.Equal(t, v, tok.Extra(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -808,9 +808,9 @@ func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidcty
|
||||
|
||||
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
|
||||
h.logger.V(debugLogLevel).Info("Pinniped: Refreshing cached token.")
|
||||
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
|
||||
upstreamOIDCIdentityProvider := h.getProvider(h.oauth2Config, h.provider, h.httpClient)
|
||||
|
||||
refreshed, err := refreshSource.Token()
|
||||
refreshed, err := upstreamOIDCIdentityProvider.PerformRefresh(ctx, refreshToken.Token)
|
||||
if err != nil {
|
||||
// Ignore errors during refresh, but return nil which will trigger the full login flow.
|
||||
return nil, nil
|
||||
@ -818,7 +818,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
|
||||
|
||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
|
||||
return upstreamOIDCIdentityProvider.ValidateToken(ctx, refreshed, "")
|
||||
}
|
||||
|
||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||
|
@ -35,6 +35,7 @@ import (
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/internal/testutil/testlogger"
|
||||
"go.pinniped.dev/internal/upstreamoidc"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
@ -404,11 +405,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(&testToken, nil)
|
||||
mock.EXPECT().
|
||||
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).
|
||||
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
// Call the real production code to perform a refresh.
|
||||
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
|
||||
})
|
||||
return mock
|
||||
}
|
||||
|
||||
@ -445,11 +452,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(nil, fmt.Errorf("some validation error"))
|
||||
mock.EXPECT().
|
||||
PerformRefresh(gomock.Any(), "test-refresh-token-returning-invalid-id-token").
|
||||
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
// Call the real production code to perform a refresh.
|
||||
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
|
||||
})
|
||||
return mock
|
||||
}
|
||||
|
||||
@ -1522,11 +1535,17 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(&testToken, nil)
|
||||
mock.EXPECT().
|
||||
PerformRefresh(gomock.Any(), testToken.RefreshToken.Token).
|
||||
DoAndReturn(func(ctx context.Context, refreshToken string) (*oauth2.Token, error) {
|
||||
// Call the real production code to perform a refresh.
|
||||
return upstreamoidc.New(config, provider, client).PerformRefresh(ctx, refreshToken)
|
||||
})
|
||||
return mock
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user