Save an http.Client with each upstreamoidc.ProviderConfig object.

This allows the token exchange request to be performed with the correct TLS configuration.

We go to a bit of extra work to make sure the `http.Client` object is cached between reconcile operations so that connection pooling works as expected.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-12-02 10:27:20 -06:00
parent c23c54f500
commit 4fe691de92
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
4 changed files with 33 additions and 24 deletions

View File

@ -70,15 +70,21 @@ type IDPCache interface {
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
type lruValidatorCache struct{ cache *cache.Expiring } type lruValidatorCache struct{ cache *cache.Expiring }
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider { type lruValidatorCacheEntry struct {
if result, ok := c.cache.Get(c.cacheKey(spec)); ok { provider *oidc.Provider
return result.(*oidc.Provider) client *http.Client
}
return nil
} }
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) { func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) {
c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL) if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
entry := result.(*lruValidatorCacheEntry)
return entry.provider, entry.client
}
return nil, nil
}
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) {
c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL)
} }
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} { func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} {
@ -97,8 +103,8 @@ type controller struct {
providers idpinformers.UpstreamOIDCProviderInformer providers idpinformers.UpstreamOIDCProviderInformer
secrets corev1informers.SecretInformer secrets corev1informers.SecretInformer
validatorCache interface { validatorCache interface {
getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client)
putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client)
} }
} }
@ -224,6 +230,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
// If everything is valid, update the result and set the condition to true. // If everything is valid, update the result and set the condition to true.
result.Config.ClientID = string(clientID) result.Config.ClientID = string(clientID)
result.Config.ClientSecret = string(clientSecret)
return &v1alpha1.Condition{ return &v1alpha1.Condition{
Type: typeClientCredsValid, Type: typeClientCredsValid,
Status: v1alpha1.ConditionTrue, Status: v1alpha1.ConditionTrue,
@ -234,8 +241,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
// Get the provider (from cache if possible). // Get the provider and HTTP Client from cache if possible.
discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec)
// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
if discoveredProvider == nil { if discoveredProvider == nil {
@ -248,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
Message: err.Error(), Message: err.Error(),
} }
} }
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer) discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
if err != nil { if err != nil {
@ -261,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
} }
// Update the cache with the newly discovered value. // Update the cache with the newly discovered value.
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient)
} }
// Parse out and validate the discovered authorize endpoint. // Parse out and validate the discovered authorize endpoint.
@ -286,6 +293,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
// If everything is valid, update the result and set the condition to true. // If everything is valid, update the result and set the condition to true.
result.Config.Endpoint = discoveredProvider.Endpoint() result.Config.Endpoint = discoveredProvider.Endpoint()
result.Provider = discoveredProvider result.Provider = discoveredProvider
result.Client = httpClient
return &v1alpha1.Condition{ return &v1alpha1.Condition{
Type: typeOIDCDiscoverySucceeded, Type: typeOIDCDiscoverySucceeded,
Status: v1alpha1.ConditionTrue, Status: v1alpha1.ConditionTrue,

View File

@ -20,8 +20,8 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
return &ProviderConfig{Config: config, Provider: provider} return &ProviderConfig{Config: config, Provider: provider, Client: client}
} }
// ProviderConfig holds the active configuration of an upstream OIDC provider. // ProviderConfig holds the active configuration of an upstream OIDC provider.
@ -33,6 +33,7 @@ type ProviderConfig struct {
Provider interface { Provider interface {
Verifier(*oidc.Config) *oidc.IDTokenVerifier Verifier(*oidc.Config) *oidc.IDTokenVerifier
} }
Client *http.Client
} }
func (p *ProviderConfig) GetName() string { func (p *ProviderConfig) GetName() string {
@ -61,7 +62,7 @@ func (p *ProviderConfig) GetGroupsClaim() string {
} }
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier()) tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier())
if err != nil { if err != nil {
return oidctypes.Token{}, nil, err return oidctypes.Token{}, nil, err
} }
@ -74,7 +75,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
if !hasIDTok { if !hasIDTok {
return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
} }
validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok) validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok)
if err != nil { if err != nil {
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
} }

View File

@ -64,7 +64,7 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -295,7 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "") token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -328,7 +328,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// Exchange the authorization code for access, ID, and refresh tokens and perform required // Exchange the authorization code for access, ID, and refresh tokens and perform required
// validations on the returned ID token. // validations on the returned ID token.
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }

View File

@ -238,7 +238,7 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
@ -277,7 +277,7 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
@ -522,7 +522,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
@ -538,7 +538,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).