Synchronize the OIDCProvider secrets cache

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2020-12-14 11:32:11 -05:00
parent e3ea141bf3
commit 2f28d2a96b
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
5 changed files with 178 additions and 79 deletions

View File

@ -158,8 +158,7 @@ func startControllers(
rand.Reader, rand.Reader,
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
plog.Debug("setting hmac secret", "issuer", parent.Spec.Issuer) plog.Debug("setting hmac secret", "issuer", parent.Spec.Issuer)
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). secretCache.SetTokenHMACKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
SetTokenHMACKey(child.Data[symmetricsecrethelper.SecretDataKey])
}, },
), ),
kubeClient, kubeClient,
@ -177,8 +176,7 @@ func startControllers(
rand.Reader, rand.Reader,
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
plog.Debug("setting state signature key", "issuer", parent.Spec.Issuer) plog.Debug("setting state signature key", "issuer", parent.Spec.Issuer)
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). secretCache.SetStateEncoderHashKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey])
}, },
), ),
kubeClient, kubeClient,
@ -196,8 +194,7 @@ func startControllers(
rand.Reader, rand.Reader,
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
plog.Debug("setting state encryption key", "issuer", parent.Spec.Issuer) plog.Debug("setting state encryption key", "issuer", parent.Spec.Issuer)
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). secretCache.SetStateEncoderBlockKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey])
}, },
), ),
kubeClient, kubeClient,

View File

@ -37,7 +37,7 @@ type Manager struct {
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
cache *secret.Cache // in-memory cache of cryptographic material secretCache *secret.Cache // in-memory cache of cryptographic material
secretsClient corev1client.SecretInterface secretsClient corev1client.SecretInterface
} }
@ -49,7 +49,7 @@ func NewManager(
nextHandler http.Handler, nextHandler http.Handler,
dynamicJWKSProvider jwks.DynamicJWKSProvider, dynamicJWKSProvider jwks.DynamicJWKSProvider,
idpListGetter oidc.IDPListGetter, idpListGetter oidc.IDPListGetter,
cache *secret.Cache, secretCache *secret.Cache,
secretsClient corev1client.SecretInterface, secretsClient corev1client.SecretInterface,
) *Manager { ) *Manager {
return &Manager{ return &Manager{
@ -57,7 +57,7 @@ func NewManager(
nextHandler: nextHandler, nextHandler: nextHandler,
dynamicJWKSProvider: dynamicJWKSProvider, dynamicJWKSProvider: dynamicJWKSProvider,
idpListGetter: idpListGetter, idpListGetter: idpListGetter,
cache: cache, secretCache: secretCache,
secretsClient: secretsClient, secretsClient: secretsClient,
} }
} }
@ -79,28 +79,28 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
var csrfCookieEncoder = dynamiccodec.New( var csrfCookieEncoder = dynamiccodec.New(
oidc.CSRFCookieLifespan, oidc.CSRFCookieLifespan,
m.cache.GetCSRFCookieEncoderHashKey, m.secretCache.GetCSRFCookieEncoderHashKey,
m.cache.GetCSRFCookieEncoderBlockKey, func() []byte { return nil },
) )
for _, incomingProvider := range oidcProviders { for _, incomingProvider := range oidcProviders {
providerCache := m.cache.GetOIDCProviderCacheFor(incomingProvider.Issuer())
issuer := incomingProvider.Issuer() issuer := incomingProvider.Issuer()
issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath()
oidcTimeouts := oidc.DefaultOIDCTimeoutsConfiguration() oidcTimeouts := oidc.DefaultOIDCTimeoutsConfiguration()
tokenHMACKeyGetter := wrapGetter(incomingProvider.Issuer(), m.secretCache.GetTokenHMACKey)
// Use NullStorage for the authorize endpoint because we do not actually want to store anything until // Use NullStorage for the authorize endpoint because we do not actually want to store anything until
// the upstream callback endpoint is called later. // the upstream callback endpoint is called later.
oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, providerCache.GetTokenHMACKey, nil, oidcTimeouts) oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, tokenHMACKeyGetter, nil, oidcTimeouts)
// For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, providerCache.GetTokenHMACKey, m.dynamicJWKSProvider, oidcTimeouts) oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, tokenHMACKeyGetter, m.dynamicJWKSProvider, oidcTimeouts)
var upstreamStateEncoder = dynamiccodec.New( var upstreamStateEncoder = dynamiccodec.New(
oidcTimeouts.UpstreamStateParamLifespan, oidcTimeouts.UpstreamStateParamLifespan,
providerCache.GetStateEncoderHashKey, wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderHashKey),
providerCache.GetStateEncoderBlockKey, wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderBlockKey),
) )
m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer)
@ -158,3 +158,9 @@ func (m *Manager) findHandler(req *http.Request) http.Handler {
return m.providerHandlers[strings.ToLower(req.Host)+"/"+req.URL.Path] return m.providerHandlers[strings.ToLower(req.Host)+"/"+req.URL.Path]
} }
func wrapGetter(issuer string, getter func(string) []byte) func() []byte {
return func() []byte {
return getter(issuer)
}
}

View File

@ -246,15 +246,13 @@ func TestManager(t *testing.T) {
cache := secret.Cache{} cache := secret.Cache{}
cache.SetCSRFCookieEncoderHashKey([]byte("fake-csrf-hash-secret")) cache.SetCSRFCookieEncoderHashKey([]byte("fake-csrf-hash-secret"))
oidcProvider1Cache := cache.GetOIDCProviderCacheFor(issuer1) cache.SetTokenHMACKey(issuer1, []byte("some secret 1 - must have at least 32 bytes"))
oidcProvider1Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-1")) cache.SetStateEncoderHashKey(issuer1, []byte("some-state-encoder-hash-key-1"))
oidcProvider1Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE01")) cache.SetStateEncoderBlockKey(issuer1, []byte("16-bytes-STATE01"))
oidcProvider1Cache.SetTokenHMACKey([]byte("some secret 1 - must have at least 32 bytes"))
oidcProvider2Cache := cache.GetOIDCProviderCacheFor(issuer2) cache.SetTokenHMACKey(issuer2, []byte("some secret 2 - must have at least 32 bytes"))
oidcProvider2Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-2")) cache.SetStateEncoderHashKey(issuer2, []byte("some-state-encoder-hash-key-2"))
oidcProvider2Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE02")) cache.SetStateEncoderBlockKey(issuer2, []byte("16-bytes-STATE02"))
oidcProvider2Cache.SetTokenHMACKey([]byte("some secret 2 - must have at least 32 bytes"))
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, &cache, secretsClient) subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, &cache, secretsClient)
}) })

View File

@ -3,77 +3,69 @@
package secret package secret
// TODO: synchronize me. import (
// TODO: use SetIssuerXXX() functions instead of returning a struct so that we don't have to worry about reentrancy. "sync"
"sync/atomic"
)
type Cache struct { type Cache struct {
csrfCookieEncoderHashKey []byte csrfCookieEncoderHashKey atomic.Value
csrfCookieEncoderBlockKey []byte oidcProviderCacheMap sync.Map
oidcProviderCacheMap map[string]*OIDCProviderCache }
// New returns an empty Cache.
func New() *Cache { return &Cache{} }
type oidcProviderCache struct {
tokenHMACKey atomic.Value
stateEncoderHashKey atomic.Value
stateEncoderBlockKey atomic.Value
} }
func (c *Cache) GetCSRFCookieEncoderHashKey() []byte { func (c *Cache) GetCSRFCookieEncoderHashKey() []byte {
return c.csrfCookieEncoderHashKey return bytesOrNil(c.csrfCookieEncoderHashKey.Load())
} }
func (c *Cache) SetCSRFCookieEncoderHashKey(key []byte) { func (c *Cache) SetCSRFCookieEncoderHashKey(key []byte) {
c.csrfCookieEncoderHashKey = key c.csrfCookieEncoderHashKey.Store(key)
} }
func (c *Cache) GetCSRFCookieEncoderBlockKey() []byte { func (c *Cache) GetTokenHMACKey(oidcIssuer string) []byte {
return c.csrfCookieEncoderBlockKey return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Load())
} }
func (c *Cache) SetCSRFCookieEncoderBlockKey(key []byte) { func (c *Cache) SetTokenHMACKey(oidcIssuer string, key []byte) {
c.csrfCookieEncoderBlockKey = key c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Store(key)
} }
func (c *Cache) GetOIDCProviderCacheFor(oidcIssuer string) *OIDCProviderCache { func (c *Cache) GetStateEncoderHashKey(oidcIssuer string) []byte {
oidcProvider, ok := c.oidcProviderCaches()[oidcIssuer] return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Load())
}
func (c *Cache) SetStateEncoderHashKey(oidcIssuer string, key []byte) {
c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Store(key)
}
func (c *Cache) GetStateEncoderBlockKey(oidcIssuer string) []byte {
return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Load())
}
func (c *Cache) SetStateEncoderBlockKey(oidcIssuer string, key []byte) {
c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Store(key)
}
func (c *Cache) getOIDCProviderCache(oidcIssuer string) *oidcProviderCache {
value, ok := c.oidcProviderCacheMap.Load(oidcIssuer)
if !ok { if !ok {
oidcProvider = &OIDCProviderCache{} value = &oidcProviderCache{}
c.oidcProviderCaches()[oidcIssuer] = oidcProvider c.oidcProviderCacheMap.Store(oidcIssuer, value)
} }
return oidcProvider return value.(*oidcProviderCache)
} }
func (c *Cache) SetOIDCProviderCacheFor(oidcIssuer string, oidcProviderCache *OIDCProviderCache) { func bytesOrNil(b interface{}) []byte {
c.oidcProviderCaches()[oidcIssuer] = oidcProviderCache if b == nil {
return nil
} }
return b.([]byte)
func (c *Cache) oidcProviderCaches() map[string]*OIDCProviderCache {
if c.oidcProviderCacheMap == nil {
c.oidcProviderCacheMap = map[string]*OIDCProviderCache{}
}
return c.oidcProviderCacheMap
}
type OIDCProviderCache struct {
tokenHMACKey []byte
stateEncoderHashKey []byte
stateEncoderBlockKey []byte
}
func (o *OIDCProviderCache) GetTokenHMACKey() []byte {
return o.tokenHMACKey
}
func (o *OIDCProviderCache) SetTokenHMACKey(key []byte) {
o.tokenHMACKey = key
}
func (o *OIDCProviderCache) GetStateEncoderHashKey() []byte {
return o.stateEncoderHashKey
}
func (o *OIDCProviderCache) SetStateEncoderHashKey(key []byte) {
o.stateEncoderHashKey = key
}
func (o *OIDCProviderCache) GetStateEncoderBlockKey() []byte {
return o.stateEncoderBlockKey
}
func (o *OIDCProviderCache) SetStateEncoderBlockKey(key []byte) {
o.stateEncoderBlockKey = key
} }

View File

@ -0,0 +1,106 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package secret
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
const (
issuer = "some-issuer"
otherIssuer = "other-issuer"
)
var (
csrfCookieEncoderHashKey = []byte("csrf-cookie-encoder-hash-key")
tokenHMACKey = []byte("token-hmac-key")
stateEncoderHashKey = []byte("state-encoder-hash-key")
otherStateEncoderHashKey = []byte("other-state-encoder-hash-key")
stateEncoderBlockKey = []byte("state-encoder-block-key")
)
func TestCache(t *testing.T) {
c := New()
// Validate we get a nil return value when stuff does not exist.
require.Nil(t, c.GetCSRFCookieEncoderHashKey())
require.Nil(t, c.GetTokenHMACKey(issuer))
require.Nil(t, c.GetStateEncoderHashKey(issuer))
require.Nil(t, c.GetStateEncoderBlockKey(issuer))
// Validate we get some nil and non-nil values when some stuff exists.
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
require.Nil(t, c.GetTokenHMACKey(issuer))
c.SetStateEncoderHashKey(issuer, stateEncoderHashKey)
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
require.Nil(t, c.GetStateEncoderBlockKey(issuer))
// Validate we get non-nil values when all stuff exists.
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
c.SetTokenHMACKey(issuer, tokenHMACKey)
c.SetStateEncoderHashKey(issuer, otherStateEncoderHashKey)
c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey)
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
require.Equal(t, otherStateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
// Validate that stuff is still nil for an unknown issuer.
require.Nil(t, c.GetTokenHMACKey(otherIssuer))
require.Nil(t, c.GetStateEncoderHashKey(otherIssuer))
require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer))
}
// TestCacheSynchronized should mimic the behavior of an OIDCProvider: multiple goroutines
// read the same fields, sequentially, from the cache.
func TestCacheSynchronized(t *testing.T) {
c := New()
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
c.SetTokenHMACKey(issuer, tokenHMACKey)
c.SetStateEncoderHashKey(issuer, stateEncoderHashKey)
c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
eg, _ := errgroup.WithContext(ctx)
eg.Go(func() error {
for i := 0; i < 100; i++ {
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
}
return nil
})
eg.Go(func() error {
for i := 0; i < 100; i++ {
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
}
return nil
})
eg.Go(func() error {
for i := 0; i < 100; i++ {
require.Nil(t, c.GetTokenHMACKey(otherIssuer))
require.Nil(t, c.GetStateEncoderHashKey(otherIssuer))
require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer))
}
return nil
})
require.NoError(t, eg.Wait())
}