diff --git a/cmd/pinniped-supervisor/main.go b/cmd/pinniped-supervisor/main.go index eb2fe2c4..43b3fa48 100644 --- a/cmd/pinniped-supervisor/main.go +++ b/cmd/pinniped-supervisor/main.go @@ -158,8 +158,7 @@ func startControllers( rand.Reader, func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { plog.Debug("setting hmac secret", "issuer", parent.Spec.Issuer) - secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). - SetTokenHMACKey(child.Data[symmetricsecrethelper.SecretDataKey]) + secretCache.SetTokenHMACKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey]) }, ), kubeClient, @@ -177,8 +176,7 @@ func startControllers( rand.Reader, func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { plog.Debug("setting state signature key", "issuer", parent.Spec.Issuer) - secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). - SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey]) + secretCache.SetStateEncoderHashKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey]) }, ), kubeClient, @@ -196,8 +194,7 @@ func startControllers( rand.Reader, func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) { plog.Debug("setting state encryption key", "issuer", parent.Spec.Issuer) - secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer). - SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey]) + secretCache.SetStateEncoderBlockKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey]) }, ), kubeClient, diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index d16ed682..fd442765 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -37,7 +37,7 @@ type Manager struct { nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs - cache *secret.Cache // in-memory cache of cryptographic material + secretCache *secret.Cache // in-memory cache of cryptographic material secretsClient corev1client.SecretInterface } @@ -49,7 +49,7 @@ func NewManager( nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter, - cache *secret.Cache, + secretCache *secret.Cache, secretsClient corev1client.SecretInterface, ) *Manager { return &Manager{ @@ -57,7 +57,7 @@ func NewManager( nextHandler: nextHandler, dynamicJWKSProvider: dynamicJWKSProvider, idpListGetter: idpListGetter, - cache: cache, + secretCache: secretCache, secretsClient: secretsClient, } } @@ -79,28 +79,28 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { var csrfCookieEncoder = dynamiccodec.New( oidc.CSRFCookieLifespan, - m.cache.GetCSRFCookieEncoderHashKey, - m.cache.GetCSRFCookieEncoderBlockKey, + m.secretCache.GetCSRFCookieEncoderHashKey, + func() []byte { return nil }, ) for _, incomingProvider := range oidcProviders { - providerCache := m.cache.GetOIDCProviderCacheFor(incomingProvider.Issuer()) - issuer := incomingProvider.Issuer() issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() oidcTimeouts := oidc.DefaultOIDCTimeoutsConfiguration() + tokenHMACKeyGetter := wrapGetter(incomingProvider.Issuer(), m.secretCache.GetTokenHMACKey) + // Use NullStorage for the authorize endpoint because we do not actually want to store anything until // the upstream callback endpoint is called later. - oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, providerCache.GetTokenHMACKey, nil, oidcTimeouts) + oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, tokenHMACKeyGetter, nil, oidcTimeouts) // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. - oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, providerCache.GetTokenHMACKey, m.dynamicJWKSProvider, oidcTimeouts) + oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, tokenHMACKeyGetter, m.dynamicJWKSProvider, oidcTimeouts) var upstreamStateEncoder = dynamiccodec.New( oidcTimeouts.UpstreamStateParamLifespan, - providerCache.GetStateEncoderHashKey, - providerCache.GetStateEncoderBlockKey, + wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderHashKey), + wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderBlockKey), ) m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) @@ -158,3 +158,9 @@ func (m *Manager) findHandler(req *http.Request) http.Handler { return m.providerHandlers[strings.ToLower(req.Host)+"/"+req.URL.Path] } + +func wrapGetter(issuer string, getter func(string) []byte) func() []byte { + return func() []byte { + return getter(issuer) + } +} diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 78b2cfd2..18ec0036 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -246,15 +246,13 @@ func TestManager(t *testing.T) { cache := secret.Cache{} cache.SetCSRFCookieEncoderHashKey([]byte("fake-csrf-hash-secret")) - oidcProvider1Cache := cache.GetOIDCProviderCacheFor(issuer1) - oidcProvider1Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-1")) - oidcProvider1Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE01")) - oidcProvider1Cache.SetTokenHMACKey([]byte("some secret 1 - must have at least 32 bytes")) + cache.SetTokenHMACKey(issuer1, []byte("some secret 1 - must have at least 32 bytes")) + cache.SetStateEncoderHashKey(issuer1, []byte("some-state-encoder-hash-key-1")) + cache.SetStateEncoderBlockKey(issuer1, []byte("16-bytes-STATE01")) - oidcProvider2Cache := cache.GetOIDCProviderCacheFor(issuer2) - oidcProvider2Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-2")) - oidcProvider2Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE02")) - oidcProvider2Cache.SetTokenHMACKey([]byte("some secret 2 - must have at least 32 bytes")) + cache.SetTokenHMACKey(issuer2, []byte("some secret 2 - must have at least 32 bytes")) + cache.SetStateEncoderHashKey(issuer2, []byte("some-state-encoder-hash-key-2")) + cache.SetStateEncoderBlockKey(issuer2, []byte("16-bytes-STATE02")) subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, &cache, secretsClient) }) diff --git a/internal/secret/cache.go b/internal/secret/cache.go index e8b39965..6c0f4063 100644 --- a/internal/secret/cache.go +++ b/internal/secret/cache.go @@ -3,77 +3,69 @@ package secret -// TODO: synchronize me. -// TODO: use SetIssuerXXX() functions instead of returning a struct so that we don't have to worry about reentrancy. +import ( + "sync" + "sync/atomic" +) type Cache struct { - csrfCookieEncoderHashKey []byte - csrfCookieEncoderBlockKey []byte - oidcProviderCacheMap map[string]*OIDCProviderCache + csrfCookieEncoderHashKey atomic.Value + oidcProviderCacheMap sync.Map +} + +// New returns an empty Cache. +func New() *Cache { return &Cache{} } + +type oidcProviderCache struct { + tokenHMACKey atomic.Value + stateEncoderHashKey atomic.Value + stateEncoderBlockKey atomic.Value } func (c *Cache) GetCSRFCookieEncoderHashKey() []byte { - return c.csrfCookieEncoderHashKey + return bytesOrNil(c.csrfCookieEncoderHashKey.Load()) } func (c *Cache) SetCSRFCookieEncoderHashKey(key []byte) { - c.csrfCookieEncoderHashKey = key + c.csrfCookieEncoderHashKey.Store(key) } -func (c *Cache) GetCSRFCookieEncoderBlockKey() []byte { - return c.csrfCookieEncoderBlockKey +func (c *Cache) GetTokenHMACKey(oidcIssuer string) []byte { + return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Load()) } -func (c *Cache) SetCSRFCookieEncoderBlockKey(key []byte) { - c.csrfCookieEncoderBlockKey = key +func (c *Cache) SetTokenHMACKey(oidcIssuer string, key []byte) { + c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Store(key) } -func (c *Cache) GetOIDCProviderCacheFor(oidcIssuer string) *OIDCProviderCache { - oidcProvider, ok := c.oidcProviderCaches()[oidcIssuer] +func (c *Cache) GetStateEncoderHashKey(oidcIssuer string) []byte { + return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Load()) +} + +func (c *Cache) SetStateEncoderHashKey(oidcIssuer string, key []byte) { + c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Store(key) +} + +func (c *Cache) GetStateEncoderBlockKey(oidcIssuer string) []byte { + return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Load()) +} + +func (c *Cache) SetStateEncoderBlockKey(oidcIssuer string, key []byte) { + c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Store(key) +} + +func (c *Cache) getOIDCProviderCache(oidcIssuer string) *oidcProviderCache { + value, ok := c.oidcProviderCacheMap.Load(oidcIssuer) if !ok { - oidcProvider = &OIDCProviderCache{} - c.oidcProviderCaches()[oidcIssuer] = oidcProvider + value = &oidcProviderCache{} + c.oidcProviderCacheMap.Store(oidcIssuer, value) } - return oidcProvider + return value.(*oidcProviderCache) } -func (c *Cache) SetOIDCProviderCacheFor(oidcIssuer string, oidcProviderCache *OIDCProviderCache) { - c.oidcProviderCaches()[oidcIssuer] = oidcProviderCache -} - -func (c *Cache) oidcProviderCaches() map[string]*OIDCProviderCache { - if c.oidcProviderCacheMap == nil { - c.oidcProviderCacheMap = map[string]*OIDCProviderCache{} +func bytesOrNil(b interface{}) []byte { + if b == nil { + return nil } - return c.oidcProviderCacheMap -} - -type OIDCProviderCache struct { - tokenHMACKey []byte - stateEncoderHashKey []byte - stateEncoderBlockKey []byte -} - -func (o *OIDCProviderCache) GetTokenHMACKey() []byte { - return o.tokenHMACKey -} - -func (o *OIDCProviderCache) SetTokenHMACKey(key []byte) { - o.tokenHMACKey = key -} - -func (o *OIDCProviderCache) GetStateEncoderHashKey() []byte { - return o.stateEncoderHashKey -} - -func (o *OIDCProviderCache) SetStateEncoderHashKey(key []byte) { - o.stateEncoderHashKey = key -} - -func (o *OIDCProviderCache) GetStateEncoderBlockKey() []byte { - return o.stateEncoderBlockKey -} - -func (o *OIDCProviderCache) SetStateEncoderBlockKey(key []byte) { - o.stateEncoderBlockKey = key + return b.([]byte) } diff --git a/internal/secret/cache_test.go b/internal/secret/cache_test.go new file mode 100644 index 00000000..40fdf612 --- /dev/null +++ b/internal/secret/cache_test.go @@ -0,0 +1,106 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package secret + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +const ( + issuer = "some-issuer" + otherIssuer = "other-issuer" +) + +var ( + csrfCookieEncoderHashKey = []byte("csrf-cookie-encoder-hash-key") + tokenHMACKey = []byte("token-hmac-key") + stateEncoderHashKey = []byte("state-encoder-hash-key") + otherStateEncoderHashKey = []byte("other-state-encoder-hash-key") + stateEncoderBlockKey = []byte("state-encoder-block-key") +) + +func TestCache(t *testing.T) { + c := New() + + // Validate we get a nil return value when stuff does not exist. + require.Nil(t, c.GetCSRFCookieEncoderHashKey()) + require.Nil(t, c.GetTokenHMACKey(issuer)) + require.Nil(t, c.GetStateEncoderHashKey(issuer)) + require.Nil(t, c.GetStateEncoderBlockKey(issuer)) + + // Validate we get some nil and non-nil values when some stuff exists. + c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey) + require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey()) + require.Nil(t, c.GetTokenHMACKey(issuer)) + c.SetStateEncoderHashKey(issuer, stateEncoderHashKey) + require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer)) + require.Nil(t, c.GetStateEncoderBlockKey(issuer)) + + // Validate we get non-nil values when all stuff exists. + c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey) + c.SetTokenHMACKey(issuer, tokenHMACKey) + c.SetStateEncoderHashKey(issuer, otherStateEncoderHashKey) + c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey) + require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey()) + require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer)) + require.Equal(t, otherStateEncoderHashKey, c.GetStateEncoderHashKey(issuer)) + require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer)) + + // Validate that stuff is still nil for an unknown issuer. + require.Nil(t, c.GetTokenHMACKey(otherIssuer)) + require.Nil(t, c.GetStateEncoderHashKey(otherIssuer)) + require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer)) +} + +// TestCacheSynchronized should mimic the behavior of an OIDCProvider: multiple goroutines +// read the same fields, sequentially, from the cache. +func TestCacheSynchronized(t *testing.T) { + c := New() + + c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey) + c.SetTokenHMACKey(issuer, tokenHMACKey) + c.SetStateEncoderHashKey(issuer, stateEncoderHashKey) + c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + eg, _ := errgroup.WithContext(ctx) + + eg.Go(func() error { + for i := 0; i < 100; i++ { + require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey()) + require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer)) + require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer)) + require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer)) + } + return nil + }) + + eg.Go(func() error { + for i := 0; i < 100; i++ { + require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey()) + require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer)) + require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer)) + require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer)) + } + return nil + }) + + eg.Go(func() error { + for i := 0; i < 100; i++ { + require.Nil(t, c.GetTokenHMACKey(otherIssuer)) + require.Nil(t, c.GetStateEncoderHashKey(otherIssuer)) + require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer)) + } + return nil + }) + + require.NoError(t, eg.Wait()) +}