From 95093ab0af59ca5b14038ca9069b3324863b1002 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Wed, 2 Dec 2020 17:39:45 -0800 Subject: [PATCH] Use kube storage for the supervisor callback endpoint's fosite sessions --- cmd/pinniped-supervisor/main.go | 7 +- internal/oidc/provider/manager/manager.go | 42 +++++---- .../oidc/provider/manager/manager_test.go | 88 ++++++++++++++++--- 3 files changed, 107 insertions(+), 30 deletions(-) diff --git a/cmd/pinniped-supervisor/main.go b/cmd/pinniped-supervisor/main.go index d2bfc7f5..31f5dff8 100644 --- a/cmd/pinniped-supervisor/main.go +++ b/cmd/pinniped-supervisor/main.go @@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error { dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider() // OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux. - oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider) + oidProvidersManager := manager.NewManager( + healthMux, + dynamicJWKSProvider, + dynamicUpstreamIDPProvider, + kubeClient.CoreV1().Secrets(serverInstallationNamespace), + ) startControllers( ctx, diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 6bac2c60..b4238273 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/gorilla/securecookie" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" @@ -32,18 +33,25 @@ type Manager struct { nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs + secretsClient corev1client.SecretInterface } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. -func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager { +func NewManager( + nextHandler http.Handler, + dynamicJWKSProvider jwks.DynamicJWKSProvider, + idpListGetter oidc.IDPListGetter, + secretsClient corev1client.SecretInterface, +) *Manager { return &Manager{ providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler, dynamicJWKSProvider: dynamicJWKSProvider, idpListGetter: idpListGetter, + secretsClient: secretsClient, } } @@ -63,15 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { m.providerHandlers = make(map[string]http.Handler) for _, incomingProvider := range oidcProviders { - wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath - m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer()) + issuer := incomingProvider.Issuer() + issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() - jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath - m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider) + fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret // Use NullStorage for the authorize endpoint because we do not actually want to store anything until // the upstream callback endpoint is called later. - oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, incomingProvider.Issuer(), []byte("some secret - must have at least 32 bytes")) // TODO replace this secret + oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider) + + // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. + oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider) // TODO use different codecs for the state and the cookie, because: // 1. we would like to state to have an embedded expiration date while the cookie does not need that @@ -82,11 +92,14 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { var encoder = securecookie.New(encoderHashKey, encoderBlockKey) encoder.SetSerializer(securecookie.JSONEncoder{}) - authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler( - incomingProvider.Issuer(), + m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) + + m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider) + + m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler( + issuer, m.idpListGetter, - oauthHelper, + oauthHelperWithNullStorage, csrftoken.Generate, pkce.Generate, nonce.Generate, @@ -94,16 +107,15 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { encoder, ) - callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath - m.providerHandlers[callbackURL] = callback.NewHandler( + m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler( m.idpListGetter, - oauthHelper, + oauthHelperWithKubeStorage, encoder, encoder, - incomingProvider.Issuer()+oidc.CallbackEndpointPath, + issuer+oidc.CallbackEndpointPath, ) - plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) + plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 44ac6398..a3f8090d 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -4,6 +4,7 @@ package manager import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -15,6 +16,7 @@ import ( "github.com/sclevine/spec" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" + "k8s.io/client-go/kubernetes/fake" "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/oidc" @@ -22,6 +24,9 @@ import ( "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" ) func TestManager(t *testing.T) { @@ -32,6 +37,7 @@ func TestManager(t *testing.T) { nextHandler http.HandlerFunc fallbackHandlerWasCalled bool dynamicJWKSProvider jwks.DynamicJWKSProvider + kubeClient *fake.Clientset ) const ( @@ -66,7 +72,7 @@ func TestManager(t *testing.T) { r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer) } - requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) { + requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) { recorder := httptest.NewRecorder() subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) @@ -81,18 +87,58 @@ func TestManager(t *testing.T) { "actual location %s did not start with expected prefix %s", actualLocation, expectedRedirectLocationPrefix, ) + + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + redirectStateParam := parsedLocation.Query().Get("state") + r.NotEmpty(redirectStateParam) + + cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2) + r.Len(cookieValueAndDirectivesSplit, 2) + cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=") + r.Len(cookieKeyValueSplit, 2) + csrfCookieName := cookieKeyValueSplit[0] + r.Equal("__Host-pinniped-csrf", csrfCookieName) + csrfCookieValue := cookieKeyValueSplit[1] + r.NotEmpty(csrfCookieValue) + + // Return the important parts of the response so we can use them in our next request to the callback endpoint + return csrfCookieValue, redirectStateParam } - requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) { + requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) { recorder := httptest.NewRecorder() - subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix)) + numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions()) + + getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix) + getRequest.AddCookie(&http.Cookie{ + Name: "__Host-pinniped-csrf", + Value: csrfCookieValue, + }) + subject.ServeHTTP(recorder, getRequest) r.False(fallbackHandlerWasCalled) - // Minimal check to ensure that the right endpoint was called - when we don't send a CSRF - // cookie to the callback endpoint, the callback endpoint responds with a 403. - r.Equal(http.StatusForbidden, recorder.Code) + // Check just enough of the response to ensure that we wired up the callback endpoint correctly. + // The endpoint's own unit tests cover everything else. + r.Equal(http.StatusFound, recorder.Code) + actualLocation := recorder.Header().Get("Location") + r.True( + strings.HasPrefix(actualLocation, downstreamRedirectURL), + "actual location %s did not start with expected prefix %s", + actualLocation, downstreamRedirectURL, + ) + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + actualLocationQueryParams := parsedLocation.Query() + r.Contains(actualLocationQueryParams, "code") + r.Equal("openid", actualLocationQueryParams.Get("scope")) + r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state")) + + // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. + r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3, + "did not perform any kube actions during the callback request, but should have") } requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { @@ -126,9 +172,22 @@ func TestManager(t *testing.T) { ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, Scopes: []string{"test-scope"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + return oidctypes.Token{}, + map[string]interface{}{ + "iss": "https://some-issuer.com", + "sub": "some-subject", + "username": "test-username", + "groups": "test-group1", + }, + nil + }, }) - subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) + kubeClient = fake.NewSimpleClientset() + secretsClient := kubeClient.CoreV1().Secrets("some-namespace") + + subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient) }) when("given no providers via SetProviders()", func() { @@ -191,19 +250,20 @@ func TestManager(t *testing.T) { // Hostnames are case-insensitive, so test that we can handle that. requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) - requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + csrfCookieValue, upstreamStateParam := + requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) callbackRequestParams := "?" + url.Values{ - "code": []string{"some-code"}, - "state": []string{"some-state-value"}, + "code": []string{"some-fake-code"}, + "state": []string{upstreamStateParam}, }.Encode() - requireCallbackRequestToBeHandled(issuer1, callbackRequestParams) - requireCallbackRequestToBeHandled(issuer2, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue) // // Hostnames are case-insensitive, so test that we can handle that. - requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams) - requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue) } when("given some valid providers via SetProviders()", func() {