Save an http.Client with each upstreamoidc.ProviderConfig object.
This allows the token exchange request to be performed with the correct TLS configuration. We go to a bit of extra work to make sure the `http.Client` object is cached between reconcile operations so that connection pooling works as expected. Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
c23c54f500
commit
4fe691de92
@ -70,15 +70,21 @@ type IDPCache interface {
|
|||||||
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
|
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
|
||||||
type lruValidatorCache struct{ cache *cache.Expiring }
|
type lruValidatorCache struct{ cache *cache.Expiring }
|
||||||
|
|
||||||
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider {
|
type lruValidatorCacheEntry struct {
|
||||||
if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
|
provider *oidc.Provider
|
||||||
return result.(*oidc.Provider)
|
client *http.Client
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) {
|
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) {
|
||||||
c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL)
|
if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
|
||||||
|
entry := result.(*lruValidatorCacheEntry)
|
||||||
|
return entry.provider, entry.client
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) {
|
||||||
|
c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} {
|
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} {
|
||||||
@ -97,8 +103,8 @@ type controller struct {
|
|||||||
providers idpinformers.UpstreamOIDCProviderInformer
|
providers idpinformers.UpstreamOIDCProviderInformer
|
||||||
secrets corev1informers.SecretInformer
|
secrets corev1informers.SecretInformer
|
||||||
validatorCache interface {
|
validatorCache interface {
|
||||||
getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider
|
getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client)
|
||||||
putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider)
|
putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,6 +230,7 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
|
|||||||
|
|
||||||
// If everything is valid, update the result and set the condition to true.
|
// If everything is valid, update the result and set the condition to true.
|
||||||
result.Config.ClientID = string(clientID)
|
result.Config.ClientID = string(clientID)
|
||||||
|
result.Config.ClientSecret = string(clientSecret)
|
||||||
return &v1alpha1.Condition{
|
return &v1alpha1.Condition{
|
||||||
Type: typeClientCredsValid,
|
Type: typeClientCredsValid,
|
||||||
Status: v1alpha1.ConditionTrue,
|
Status: v1alpha1.ConditionTrue,
|
||||||
@ -234,8 +241,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
|
|||||||
|
|
||||||
// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
|
// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
|
||||||
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
|
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
|
||||||
// Get the provider (from cache if possible).
|
// Get the provider and HTTP Client from cache if possible.
|
||||||
discoveredProvider := c.validatorCache.getProvider(&upstream.Spec)
|
discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec)
|
||||||
|
|
||||||
// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
|
// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
|
||||||
if discoveredProvider == nil {
|
if discoveredProvider == nil {
|
||||||
@ -248,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
|||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
|
httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
|
||||||
|
|
||||||
discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
|
discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -261,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update the cache with the newly discovered value.
|
// Update the cache with the newly discovered value.
|
||||||
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider)
|
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse out and validate the discovered authorize endpoint.
|
// Parse out and validate the discovered authorize endpoint.
|
||||||
@ -286,6 +293,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
|||||||
// If everything is valid, update the result and set the condition to true.
|
// If everything is valid, update the result and set the condition to true.
|
||||||
result.Config.Endpoint = discoveredProvider.Endpoint()
|
result.Config.Endpoint = discoveredProvider.Endpoint()
|
||||||
result.Provider = discoveredProvider
|
result.Provider = discoveredProvider
|
||||||
|
result.Client = httpClient
|
||||||
return &v1alpha1.Condition{
|
return &v1alpha1.Condition{
|
||||||
Type: typeOIDCDiscoverySucceeded,
|
Type: typeOIDCDiscoverySucceeded,
|
||||||
Status: v1alpha1.ConditionTrue,
|
Status: v1alpha1.ConditionTrue,
|
||||||
|
@ -20,8 +20,8 @@ import (
|
|||||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||||
)
|
)
|
||||||
|
|
||||||
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
return &ProviderConfig{Config: config, Provider: provider}
|
return &ProviderConfig{Config: config, Provider: provider, Client: client}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
||||||
@ -33,6 +33,7 @@ type ProviderConfig struct {
|
|||||||
Provider interface {
|
Provider interface {
|
||||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||||
}
|
}
|
||||||
|
Client *http.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProviderConfig) GetName() string {
|
func (p *ProviderConfig) GetName() string {
|
||||||
@ -61,7 +62,7 @@ func (p *ProviderConfig) GetGroupsClaim() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||||
tok, err := p.Config.Exchange(ctx, authcode, pkceCodeVerifier.Verifier())
|
tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oidctypes.Token{}, nil, err
|
return oidctypes.Token{}, nil, err
|
||||||
}
|
}
|
||||||
@ -74,7 +75,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e
|
|||||||
if !hasIDTok {
|
if !hasIDTok {
|
||||||
return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||||
}
|
}
|
||||||
validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(ctx, idTok)
|
validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ type handlerState struct {
|
|||||||
generatePKCE func() (pkce.Code, error)
|
generatePKCE func() (pkce.Code, error)
|
||||||
generateNonce func() (nonce.Nonce, error)
|
generateNonce func() (nonce.Nonce, error)
|
||||||
openURL func(string) error
|
openURL func(string) error
|
||||||
getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI
|
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -295,7 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
|
|||||||
|
|
||||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||||
token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "")
|
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -328,7 +328,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
|||||||
|
|
||||||
// Exchange the authorization code for access, ID, and refresh tokens and perform required
|
// Exchange the authorization code for access, ID, and refresh tokens and perform required
|
||||||
// validations on the returned ID token.
|
// validations on the returned ID token.
|
||||||
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
|
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
||||||
}
|
}
|
||||||
|
@ -238,7 +238,7 @@ func TestLogin(t *testing.T) {
|
|||||||
clientID: "test-client-id",
|
clientID: "test-client-id",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
mock := mockUpstream(t)
|
mock := mockUpstream(t)
|
||||||
mock.EXPECT().
|
mock.EXPECT().
|
||||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||||
@ -277,7 +277,7 @@ func TestLogin(t *testing.T) {
|
|||||||
clientID: "test-client-id",
|
clientID: "test-client-id",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
mock := mockUpstream(t)
|
mock := mockUpstream(t)
|
||||||
mock.EXPECT().
|
mock.EXPECT().
|
||||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||||
@ -522,7 +522,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
|||||||
wantHTTPStatus: http.StatusBadRequest,
|
wantHTTPStatus: http.StatusBadRequest,
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
mock := mockUpstream(t)
|
mock := mockUpstream(t)
|
||||||
mock.EXPECT().
|
mock.EXPECT().
|
||||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||||
@ -538,7 +538,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
|||||||
query: "state=test-state&code=valid",
|
query: "state=test-state&code=valid",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
mock := mockUpstream(t)
|
mock := mockUpstream(t)
|
||||||
mock.EXPECT().
|
mock.EXPECT().
|
||||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||||
|
Loading…
Reference in New Issue
Block a user