diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go index 4955646b..7faa4d9c 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go @@ -70,15 +70,21 @@ type IDPCache interface { // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. type lruValidatorCache struct{ cache *cache.Expiring } -func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider { - if result, ok := c.cache.Get(c.cacheKey(spec)); ok { - return result.(*oidc.Provider) - } - return nil +type lruValidatorCacheEntry struct { + provider *oidc.Provider + client *http.Client } -func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) { - c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL) +func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) { + if result, ok := c.cache.Get(c.cacheKey(spec)); ok { + entry := result.(*lruValidatorCacheEntry) + return entry.provider, entry.client + } + return nil, nil +} + +func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) { + c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL) } func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} { @@ -97,8 +103,8 @@ type controller struct { providers idpinformers.UpstreamOIDCProviderInformer secrets corev1informers.SecretInformer validatorCache interface { - getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider - putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) + getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) + putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client) } } @@ -224,6 +230,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res // If everything is valid, update the result and set the condition to true. result.Config.ClientID = string(clientID) + result.Config.ClientSecret = string(clientSecret) return &v1alpha1.Condition{ Type: typeClientCredsValid, Status: v1alpha1.ConditionTrue, @@ -234,8 +241,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { - // Get the provider (from cache if possible). - discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) + // Get the provider and HTTP Client from cache if possible. + discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec) // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. if discoveredProvider == nil { @@ -248,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst Message: err.Error(), } } - httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} + httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer) if err != nil { @@ -261,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst } // Update the cache with the newly discovered value. - c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) + c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient) } // Parse out and validate the discovered authorize endpoint. @@ -286,6 +293,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst // If everything is valid, update the result and set the condition to true. result.Config.Endpoint = discoveredProvider.Endpoint() result.Provider = discoveredProvider + result.Client = httpClient return &v1alpha1.Condition{ Type: typeOIDCDiscoverySucceeded, Status: v1alpha1.ConditionTrue, diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 2957e9e3..4af7efdb 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -20,8 +20,8 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) -func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { - return &ProviderConfig{Config: config, Provider: provider} +func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { + return &ProviderConfig{Config: config, Provider: provider, Client: client} } // ProviderConfig holds the active configuration of an upstream OIDC provider. @@ -33,6 +33,7 @@ type ProviderConfig struct { Provider interface { Verifier(*oidc.Config) *oidc.IDTokenVerifier } + Client *http.Client } func (p *ProviderConfig) GetName() string { @@ -61,7 +62,7 @@ func (p *ProviderConfig) GetGroupsClaim() string { } func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier()) + tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier()) if err != nil { return oidctypes.Token{}, nil, err } @@ -74,7 +75,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e if !hasIDTok { return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") } - validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok) + validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok) if err != nil { return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index fbbe23a9..2b21e080 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -64,7 +64,7 @@ type handlerState struct { generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error - getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI + getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI callbacks chan callbackResult } @@ -295,7 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "") + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") if err != nil { return nil, err } @@ -328,7 +328,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 280dfd0a..374d90e3 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -238,7 +238,7 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). @@ -277,7 +277,7 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). @@ -522,7 +522,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantHTTPStatus: http.StatusBadRequest, opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). @@ -538,7 +538,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { query: "state=test-state&code=valid", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).