Use kube storage for the supervisor callback endpoint's fosite sessions

This commit is contained in:
Ryan Richard 2020-12-02 17:39:45 -08:00
parent 64ef53402d
commit 95093ab0af
3 changed files with 107 additions and 30 deletions

View File

@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error {
dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider() dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider()
// OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux. // OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux.
oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider) oidProvidersManager := manager.NewManager(
healthMux,
dynamicJWKSProvider,
dynamicUpstreamIDPProvider,
kubeClient.CoreV1().Secrets(serverInstallationNamespace),
)
startControllers( startControllers(
ctx, ctx,

View File

@ -9,6 +9,7 @@ import (
"sync" "sync"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/auth" "go.pinniped.dev/internal/oidc/auth"
@ -32,18 +33,25 @@ type Manager struct {
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
secretsClient corev1client.SecretInterface
} }
// NewManager returns an empty Manager. // NewManager returns an empty Manager.
// nextHandler will be invoked for any requests that could not be handled by this manager's providers. // nextHandler will be invoked for any requests that could not be handled by this manager's providers.
// dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data.
// idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs.
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager { func NewManager(
nextHandler http.Handler,
dynamicJWKSProvider jwks.DynamicJWKSProvider,
idpListGetter oidc.IDPListGetter,
secretsClient corev1client.SecretInterface,
) *Manager {
return &Manager{ return &Manager{
providerHandlers: make(map[string]http.Handler), providerHandlers: make(map[string]http.Handler),
nextHandler: nextHandler, nextHandler: nextHandler,
dynamicJWKSProvider: dynamicJWKSProvider, dynamicJWKSProvider: dynamicJWKSProvider,
idpListGetter: idpListGetter, idpListGetter: idpListGetter,
secretsClient: secretsClient,
} }
} }
@ -63,15 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
m.providerHandlers = make(map[string]http.Handler) m.providerHandlers = make(map[string]http.Handler)
for _, incomingProvider := range oidcProviders { for _, incomingProvider := range oidcProviders {
wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath issuer := incomingProvider.Issuer()
m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer()) issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath()
jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret
m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider)
// Use NullStorage for the authorize endpoint because we do not actually want to store anything until // Use NullStorage for the authorize endpoint because we do not actually want to store anything until
// the upstream callback endpoint is called later. // the upstream callback endpoint is called later.
oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, incomingProvider.Issuer(), []byte("some secret - must have at least 32 bytes")) // TODO replace this secret oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider)
// For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider)
// TODO use different codecs for the state and the cookie, because: // TODO use different codecs for the state and the cookie, because:
// 1. we would like to state to have an embedded expiration date while the cookie does not need that // 1. we would like to state to have an embedded expiration date while the cookie does not need that
@ -82,11 +92,14 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
var encoder = securecookie.New(encoderHashKey, encoderBlockKey) var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
encoder.SetSerializer(securecookie.JSONEncoder{}) encoder.SetSerializer(securecookie.JSONEncoder{})
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer)
m.providerHandlers[authURL] = auth.NewHandler(
incomingProvider.Issuer(), m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider)
m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler(
issuer,
m.idpListGetter, m.idpListGetter,
oauthHelper, oauthHelperWithNullStorage,
csrftoken.Generate, csrftoken.Generate,
pkce.Generate, pkce.Generate,
nonce.Generate, nonce.Generate,
@ -94,16 +107,15 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
encoder, encoder,
) )
callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler(
m.providerHandlers[callbackURL] = callback.NewHandler(
m.idpListGetter, m.idpListGetter,
oauthHelper, oauthHelperWithKubeStorage,
encoder, encoder,
encoder, encoder,
incomingProvider.Issuer()+oidc.CallbackEndpointPath, issuer+oidc.CallbackEndpointPath,
) )
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer)
} }
} }

View File

@ -4,6 +4,7 @@
package manager package manager
import ( import (
"context"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -15,6 +16,7 @@ import (
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"k8s.io/client-go/kubernetes/fake"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
@ -22,6 +24,9 @@ import (
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
) )
func TestManager(t *testing.T) { func TestManager(t *testing.T) {
@ -32,6 +37,7 @@ func TestManager(t *testing.T) {
nextHandler http.HandlerFunc nextHandler http.HandlerFunc
fallbackHandlerWasCalled bool fallbackHandlerWasCalled bool
dynamicJWKSProvider jwks.DynamicJWKSProvider dynamicJWKSProvider jwks.DynamicJWKSProvider
kubeClient *fake.Clientset
) )
const ( const (
@ -66,7 +72,7 @@ func TestManager(t *testing.T) {
r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer) r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer)
} }
requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) { requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix))
@ -81,18 +87,58 @@ func TestManager(t *testing.T) {
"actual location %s did not start with expected prefix %s", "actual location %s did not start with expected prefix %s",
actualLocation, expectedRedirectLocationPrefix, actualLocation, expectedRedirectLocationPrefix,
) )
parsedLocation, err := url.Parse(actualLocation)
r.NoError(err)
redirectStateParam := parsedLocation.Query().Get("state")
r.NotEmpty(redirectStateParam)
cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2)
r.Len(cookieValueAndDirectivesSplit, 2)
cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=")
r.Len(cookieKeyValueSplit, 2)
csrfCookieName := cookieKeyValueSplit[0]
r.Equal("__Host-pinniped-csrf", csrfCookieName)
csrfCookieValue := cookieKeyValueSplit[1]
r.NotEmpty(csrfCookieValue)
// Return the important parts of the response so we can use them in our next request to the callback endpoint
return csrfCookieValue, redirectStateParam
} }
requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) { requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix)) numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions())
getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix)
getRequest.AddCookie(&http.Cookie{
Name: "__Host-pinniped-csrf",
Value: csrfCookieValue,
})
subject.ServeHTTP(recorder, getRequest)
r.False(fallbackHandlerWasCalled) r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right endpoint was called - when we don't send a CSRF // Check just enough of the response to ensure that we wired up the callback endpoint correctly.
// cookie to the callback endpoint, the callback endpoint responds with a 403. // The endpoint's own unit tests cover everything else.
r.Equal(http.StatusForbidden, recorder.Code) r.Equal(http.StatusFound, recorder.Code)
actualLocation := recorder.Header().Get("Location")
r.True(
strings.HasPrefix(actualLocation, downstreamRedirectURL),
"actual location %s did not start with expected prefix %s",
actualLocation, downstreamRedirectURL,
)
parsedLocation, err := url.Parse(actualLocation)
r.NoError(err)
actualLocationQueryParams := parsedLocation.Query()
r.Contains(actualLocationQueryParams, "code")
r.Equal("openid", actualLocationQueryParams.Get("scope"))
r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state"))
// Make sure that we wired up the callback endpoint to use kube storage for fosite sessions.
r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3,
"did not perform any kube actions during the callback request, but should have")
} }
requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) {
@ -126,9 +172,22 @@ func TestManager(t *testing.T) {
ClientID: "test-client-id", ClientID: "test-client-id",
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
Scopes: []string{"test-scope"}, Scopes: []string{"test-scope"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
return oidctypes.Token{},
map[string]interface{}{
"iss": "https://some-issuer.com",
"sub": "some-subject",
"username": "test-username",
"groups": "test-group1",
},
nil
},
}) })
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) kubeClient = fake.NewSimpleClientset()
secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient)
}) })
when("given no providers via SetProviders()", func() { when("given no providers via SetProviders()", func() {
@ -191,19 +250,20 @@ func TestManager(t *testing.T) {
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) csrfCookieValue, upstreamStateParam :=
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
callbackRequestParams := "?" + url.Values{ callbackRequestParams := "?" + url.Values{
"code": []string{"some-code"}, "code": []string{"some-fake-code"},
"state": []string{"some-state-value"}, "state": []string{upstreamStateParam},
}.Encode() }.Encode()
requireCallbackRequestToBeHandled(issuer1, callbackRequestParams) requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue)
requireCallbackRequestToBeHandled(issuer2, callbackRequestParams) requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue)
// // Hostnames are case-insensitive, so test that we can handle that. // // Hostnames are case-insensitive, so test that we can handle that.
requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams) requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams) requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
} }
when("given some valid providers via SetProviders()", func() { when("given some valid providers via SetProviders()", func() {