Synchronize the OIDCProvider secrets cache
Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
parent
e3ea141bf3
commit
2f28d2a96b
@ -158,8 +158,7 @@ func startControllers(
|
||||
rand.Reader,
|
||||
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
|
||||
plog.Debug("setting hmac secret", "issuer", parent.Spec.Issuer)
|
||||
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer).
|
||||
SetTokenHMACKey(child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
secretCache.SetTokenHMACKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
},
|
||||
),
|
||||
kubeClient,
|
||||
@ -177,8 +176,7 @@ func startControllers(
|
||||
rand.Reader,
|
||||
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
|
||||
plog.Debug("setting state signature key", "issuer", parent.Spec.Issuer)
|
||||
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer).
|
||||
SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
secretCache.SetStateEncoderHashKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
},
|
||||
),
|
||||
kubeClient,
|
||||
@ -196,8 +194,7 @@ func startControllers(
|
||||
rand.Reader,
|
||||
func(parent *configv1alpha1.OIDCProvider, child *corev1.Secret) {
|
||||
plog.Debug("setting state encryption key", "issuer", parent.Spec.Issuer)
|
||||
secretCache.GetOIDCProviderCacheFor(parent.Spec.Issuer).
|
||||
SetStateEncoderHashKey(child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
secretCache.SetStateEncoderBlockKey(parent.Spec.Issuer, child.Data[symmetricsecrethelper.SecretDataKey])
|
||||
},
|
||||
),
|
||||
kubeClient,
|
||||
|
@ -37,7 +37,7 @@ type Manager struct {
|
||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
|
||||
idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
|
||||
cache *secret.Cache // in-memory cache of cryptographic material
|
||||
secretCache *secret.Cache // in-memory cache of cryptographic material
|
||||
secretsClient corev1client.SecretInterface
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ func NewManager(
|
||||
nextHandler http.Handler,
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider,
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
cache *secret.Cache,
|
||||
secretCache *secret.Cache,
|
||||
secretsClient corev1client.SecretInterface,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
@ -57,7 +57,7 @@ func NewManager(
|
||||
nextHandler: nextHandler,
|
||||
dynamicJWKSProvider: dynamicJWKSProvider,
|
||||
idpListGetter: idpListGetter,
|
||||
cache: cache,
|
||||
secretCache: secretCache,
|
||||
secretsClient: secretsClient,
|
||||
}
|
||||
}
|
||||
@ -79,28 +79,28 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||
|
||||
var csrfCookieEncoder = dynamiccodec.New(
|
||||
oidc.CSRFCookieLifespan,
|
||||
m.cache.GetCSRFCookieEncoderHashKey,
|
||||
m.cache.GetCSRFCookieEncoderBlockKey,
|
||||
m.secretCache.GetCSRFCookieEncoderHashKey,
|
||||
func() []byte { return nil },
|
||||
)
|
||||
|
||||
for _, incomingProvider := range oidcProviders {
|
||||
providerCache := m.cache.GetOIDCProviderCacheFor(incomingProvider.Issuer())
|
||||
|
||||
issuer := incomingProvider.Issuer()
|
||||
issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath()
|
||||
oidcTimeouts := oidc.DefaultOIDCTimeoutsConfiguration()
|
||||
|
||||
tokenHMACKeyGetter := wrapGetter(incomingProvider.Issuer(), m.secretCache.GetTokenHMACKey)
|
||||
|
||||
// Use NullStorage for the authorize endpoint because we do not actually want to store anything until
|
||||
// the upstream callback endpoint is called later.
|
||||
oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, providerCache.GetTokenHMACKey, nil, oidcTimeouts)
|
||||
oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, tokenHMACKeyGetter, nil, oidcTimeouts)
|
||||
|
||||
// For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
|
||||
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, providerCache.GetTokenHMACKey, m.dynamicJWKSProvider, oidcTimeouts)
|
||||
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, tokenHMACKeyGetter, m.dynamicJWKSProvider, oidcTimeouts)
|
||||
|
||||
var upstreamStateEncoder = dynamiccodec.New(
|
||||
oidcTimeouts.UpstreamStateParamLifespan,
|
||||
providerCache.GetStateEncoderHashKey,
|
||||
providerCache.GetStateEncoderBlockKey,
|
||||
wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderHashKey),
|
||||
wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderBlockKey),
|
||||
)
|
||||
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer)
|
||||
@ -158,3 +158,9 @@ func (m *Manager) findHandler(req *http.Request) http.Handler {
|
||||
|
||||
return m.providerHandlers[strings.ToLower(req.Host)+"/"+req.URL.Path]
|
||||
}
|
||||
|
||||
func wrapGetter(issuer string, getter func(string) []byte) func() []byte {
|
||||
return func() []byte {
|
||||
return getter(issuer)
|
||||
}
|
||||
}
|
||||
|
@ -246,15 +246,13 @@ func TestManager(t *testing.T) {
|
||||
cache := secret.Cache{}
|
||||
cache.SetCSRFCookieEncoderHashKey([]byte("fake-csrf-hash-secret"))
|
||||
|
||||
oidcProvider1Cache := cache.GetOIDCProviderCacheFor(issuer1)
|
||||
oidcProvider1Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-1"))
|
||||
oidcProvider1Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE01"))
|
||||
oidcProvider1Cache.SetTokenHMACKey([]byte("some secret 1 - must have at least 32 bytes"))
|
||||
cache.SetTokenHMACKey(issuer1, []byte("some secret 1 - must have at least 32 bytes"))
|
||||
cache.SetStateEncoderHashKey(issuer1, []byte("some-state-encoder-hash-key-1"))
|
||||
cache.SetStateEncoderBlockKey(issuer1, []byte("16-bytes-STATE01"))
|
||||
|
||||
oidcProvider2Cache := cache.GetOIDCProviderCacheFor(issuer2)
|
||||
oidcProvider2Cache.SetStateEncoderHashKey([]byte("some-state-encoder-hash-key-2"))
|
||||
oidcProvider2Cache.SetStateEncoderBlockKey([]byte("16-bytes-STATE02"))
|
||||
oidcProvider2Cache.SetTokenHMACKey([]byte("some secret 2 - must have at least 32 bytes"))
|
||||
cache.SetTokenHMACKey(issuer2, []byte("some secret 2 - must have at least 32 bytes"))
|
||||
cache.SetStateEncoderHashKey(issuer2, []byte("some-state-encoder-hash-key-2"))
|
||||
cache.SetStateEncoderBlockKey(issuer2, []byte("16-bytes-STATE02"))
|
||||
|
||||
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, &cache, secretsClient)
|
||||
})
|
||||
|
@ -3,77 +3,69 @@
|
||||
|
||||
package secret
|
||||
|
||||
// TODO: synchronize me.
|
||||
// TODO: use SetIssuerXXX() functions instead of returning a struct so that we don't have to worry about reentrancy.
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type Cache struct {
|
||||
csrfCookieEncoderHashKey []byte
|
||||
csrfCookieEncoderBlockKey []byte
|
||||
oidcProviderCacheMap map[string]*OIDCProviderCache
|
||||
csrfCookieEncoderHashKey atomic.Value
|
||||
oidcProviderCacheMap sync.Map
|
||||
}
|
||||
|
||||
// New returns an empty Cache.
|
||||
func New() *Cache { return &Cache{} }
|
||||
|
||||
type oidcProviderCache struct {
|
||||
tokenHMACKey atomic.Value
|
||||
stateEncoderHashKey atomic.Value
|
||||
stateEncoderBlockKey atomic.Value
|
||||
}
|
||||
|
||||
func (c *Cache) GetCSRFCookieEncoderHashKey() []byte {
|
||||
return c.csrfCookieEncoderHashKey
|
||||
return bytesOrNil(c.csrfCookieEncoderHashKey.Load())
|
||||
}
|
||||
|
||||
func (c *Cache) SetCSRFCookieEncoderHashKey(key []byte) {
|
||||
c.csrfCookieEncoderHashKey = key
|
||||
c.csrfCookieEncoderHashKey.Store(key)
|
||||
}
|
||||
|
||||
func (c *Cache) GetCSRFCookieEncoderBlockKey() []byte {
|
||||
return c.csrfCookieEncoderBlockKey
|
||||
func (c *Cache) GetTokenHMACKey(oidcIssuer string) []byte {
|
||||
return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Load())
|
||||
}
|
||||
|
||||
func (c *Cache) SetCSRFCookieEncoderBlockKey(key []byte) {
|
||||
c.csrfCookieEncoderBlockKey = key
|
||||
func (c *Cache) SetTokenHMACKey(oidcIssuer string, key []byte) {
|
||||
c.getOIDCProviderCache(oidcIssuer).tokenHMACKey.Store(key)
|
||||
}
|
||||
|
||||
func (c *Cache) GetOIDCProviderCacheFor(oidcIssuer string) *OIDCProviderCache {
|
||||
oidcProvider, ok := c.oidcProviderCaches()[oidcIssuer]
|
||||
func (c *Cache) GetStateEncoderHashKey(oidcIssuer string) []byte {
|
||||
return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Load())
|
||||
}
|
||||
|
||||
func (c *Cache) SetStateEncoderHashKey(oidcIssuer string, key []byte) {
|
||||
c.getOIDCProviderCache(oidcIssuer).stateEncoderHashKey.Store(key)
|
||||
}
|
||||
|
||||
func (c *Cache) GetStateEncoderBlockKey(oidcIssuer string) []byte {
|
||||
return bytesOrNil(c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Load())
|
||||
}
|
||||
|
||||
func (c *Cache) SetStateEncoderBlockKey(oidcIssuer string, key []byte) {
|
||||
c.getOIDCProviderCache(oidcIssuer).stateEncoderBlockKey.Store(key)
|
||||
}
|
||||
|
||||
func (c *Cache) getOIDCProviderCache(oidcIssuer string) *oidcProviderCache {
|
||||
value, ok := c.oidcProviderCacheMap.Load(oidcIssuer)
|
||||
if !ok {
|
||||
oidcProvider = &OIDCProviderCache{}
|
||||
c.oidcProviderCaches()[oidcIssuer] = oidcProvider
|
||||
value = &oidcProviderCache{}
|
||||
c.oidcProviderCacheMap.Store(oidcIssuer, value)
|
||||
}
|
||||
return oidcProvider
|
||||
return value.(*oidcProviderCache)
|
||||
}
|
||||
|
||||
func (c *Cache) SetOIDCProviderCacheFor(oidcIssuer string, oidcProviderCache *OIDCProviderCache) {
|
||||
c.oidcProviderCaches()[oidcIssuer] = oidcProviderCache
|
||||
func bytesOrNil(b interface{}) []byte {
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cache) oidcProviderCaches() map[string]*OIDCProviderCache {
|
||||
if c.oidcProviderCacheMap == nil {
|
||||
c.oidcProviderCacheMap = map[string]*OIDCProviderCache{}
|
||||
}
|
||||
return c.oidcProviderCacheMap
|
||||
}
|
||||
|
||||
type OIDCProviderCache struct {
|
||||
tokenHMACKey []byte
|
||||
stateEncoderHashKey []byte
|
||||
stateEncoderBlockKey []byte
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) GetTokenHMACKey() []byte {
|
||||
return o.tokenHMACKey
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) SetTokenHMACKey(key []byte) {
|
||||
o.tokenHMACKey = key
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) GetStateEncoderHashKey() []byte {
|
||||
return o.stateEncoderHashKey
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) SetStateEncoderHashKey(key []byte) {
|
||||
o.stateEncoderHashKey = key
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) GetStateEncoderBlockKey() []byte {
|
||||
return o.stateEncoderBlockKey
|
||||
}
|
||||
|
||||
func (o *OIDCProviderCache) SetStateEncoderBlockKey(key []byte) {
|
||||
o.stateEncoderBlockKey = key
|
||||
return b.([]byte)
|
||||
}
|
||||
|
106
internal/secret/cache_test.go
Normal file
106
internal/secret/cache_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package secret
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
issuer = "some-issuer"
|
||||
otherIssuer = "other-issuer"
|
||||
)
|
||||
|
||||
var (
|
||||
csrfCookieEncoderHashKey = []byte("csrf-cookie-encoder-hash-key")
|
||||
tokenHMACKey = []byte("token-hmac-key")
|
||||
stateEncoderHashKey = []byte("state-encoder-hash-key")
|
||||
otherStateEncoderHashKey = []byte("other-state-encoder-hash-key")
|
||||
stateEncoderBlockKey = []byte("state-encoder-block-key")
|
||||
)
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
c := New()
|
||||
|
||||
// Validate we get a nil return value when stuff does not exist.
|
||||
require.Nil(t, c.GetCSRFCookieEncoderHashKey())
|
||||
require.Nil(t, c.GetTokenHMACKey(issuer))
|
||||
require.Nil(t, c.GetStateEncoderHashKey(issuer))
|
||||
require.Nil(t, c.GetStateEncoderBlockKey(issuer))
|
||||
|
||||
// Validate we get some nil and non-nil values when some stuff exists.
|
||||
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
|
||||
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
|
||||
require.Nil(t, c.GetTokenHMACKey(issuer))
|
||||
c.SetStateEncoderHashKey(issuer, stateEncoderHashKey)
|
||||
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
|
||||
require.Nil(t, c.GetStateEncoderBlockKey(issuer))
|
||||
|
||||
// Validate we get non-nil values when all stuff exists.
|
||||
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
|
||||
c.SetTokenHMACKey(issuer, tokenHMACKey)
|
||||
c.SetStateEncoderHashKey(issuer, otherStateEncoderHashKey)
|
||||
c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey)
|
||||
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
|
||||
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
|
||||
require.Equal(t, otherStateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
|
||||
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
|
||||
|
||||
// Validate that stuff is still nil for an unknown issuer.
|
||||
require.Nil(t, c.GetTokenHMACKey(otherIssuer))
|
||||
require.Nil(t, c.GetStateEncoderHashKey(otherIssuer))
|
||||
require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer))
|
||||
}
|
||||
|
||||
// TestCacheSynchronized should mimic the behavior of an OIDCProvider: multiple goroutines
|
||||
// read the same fields, sequentially, from the cache.
|
||||
func TestCacheSynchronized(t *testing.T) {
|
||||
c := New()
|
||||
|
||||
c.SetCSRFCookieEncoderHashKey(csrfCookieEncoderHashKey)
|
||||
c.SetTokenHMACKey(issuer, tokenHMACKey)
|
||||
c.SetStateEncoderHashKey(issuer, stateEncoderHashKey)
|
||||
c.SetStateEncoderBlockKey(issuer, stateEncoderBlockKey)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
|
||||
defer cancel()
|
||||
|
||||
eg, _ := errgroup.WithContext(ctx)
|
||||
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 100; i++ {
|
||||
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
|
||||
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
|
||||
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
|
||||
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 100; i++ {
|
||||
require.Equal(t, csrfCookieEncoderHashKey, c.GetCSRFCookieEncoderHashKey())
|
||||
require.Equal(t, tokenHMACKey, c.GetTokenHMACKey(issuer))
|
||||
require.Equal(t, stateEncoderHashKey, c.GetStateEncoderHashKey(issuer))
|
||||
require.Equal(t, stateEncoderBlockKey, c.GetStateEncoderBlockKey(issuer))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 100; i++ {
|
||||
require.Nil(t, c.GetTokenHMACKey(otherIssuer))
|
||||
require.Nil(t, c.GetStateEncoderHashKey(otherIssuer))
|
||||
require.Nil(t, c.GetStateEncoderBlockKey(otherIssuer))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
require.NoError(t, eg.Wait())
|
||||
}
|
Loading…
Reference in New Issue
Block a user