diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 0c11b5be..0626c48c 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -5,7 +5,6 @@ package manager import ( "net/http" - "strings" "sync" "k8s.io/klog/v2" @@ -20,19 +19,15 @@ import ( // It is thread-safe. type Manager struct { mu sync.RWMutex - providerHandlers map[string]*providerHandler // map of issuer name to providerHandler - nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request + providers []*provider.OIDCProvider + providerHandlers map[string]http.Handler // map of all routes for all providers + nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. func NewManager(nextHandler http.Handler) *Manager { - return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler} -} - -type providerHandler struct { - provider *provider.OIDCProvider - discoveryHandler http.Handler + return &Manager{providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler} } // SetProviders adds or updates all the given providerHandlers using each provider's issuer string @@ -43,69 +38,40 @@ type providerHandler struct { // // This method assumes that all of the OIDCProvider arguments have already been validated // by someone else before they are passed to this method. -func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { - c.mu.Lock() - defer c.mu.Unlock() - // Add all of the incoming providers. +func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { + m.mu.Lock() + defer m.mu.Unlock() + + m.providers = oidcProviders + m.providerHandlers = make(map[string]http.Handler) + for _, incomingProvider := range oidcProviders { - issuerString := incomingProvider.Issuer() - c.providerHandlers[issuerString] = &providerHandler{ - provider: incomingProvider, - discoveryHandler: discovery.New(issuerString), - } - klog.InfoS("oidc provider manager added or updated issuer", "issuer", issuerString) - } - // Remove any providers that we previously handled but no longer exist. - for issuerKey := range c.providerHandlers { - if !findIssuerInListOfProviders(issuerKey, oidcProviders) { - delete(c.providerHandlers, issuerKey) - klog.InfoS("oidc provider manager removed issuer", "issuer", issuerKey) - } + m.providerHandlers[incomingProvider.IssuerHost()+"/"+incomingProvider.IssuerPath()+oidc.WellKnownEndpointPath] = discovery.New(incomingProvider.Issuer()) + klog.InfoS("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } } // ServeHTTP implements the http.Handler interface. -func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { - providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path) - if providerHandler != nil { - if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath { - providerHandler.discoveryHandler.ServeHTTP(resp, req) - return // handled! - } - klog.InfoS( - "oidc provider manager found issuer but could not handle request", - "method", req.Method, - "host", req.Host, - "path", req.URL.Path, - ) - } else { - klog.InfoS( - "oidc provider manager could not find issuer to handle request", - "method", req.Method, - "host", req.Host, - "path", req.URL.Path, - ) +func (m *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + requestHandler := m.findHandler(req) + + klog.InfoS( + "oidc provider manager examining request", + "method", req.Method, + "host", req.Host, + "path", req.URL.Path, + "foundMatchingIssuer", requestHandler != nil, + ) + + if requestHandler == nil { + requestHandler = m.nextHandler // couldn't find an issuer to handle the request } - // Didn't know how to handle this request, so send it along the chain for further processing. - c.nextHandler.ServeHTTP(resp, req) + requestHandler.ServeHTTP(resp, req) } -func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler { - for _, providerHandler := range c.providerHandlers { - // TODO do we need to compare scheme? not sure how to get it from the http.Request object - // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux - if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) { - return providerHandler - } - } - return nil -} +func (m *Manager) findHandler(req *http.Request) http.Handler { + m.mu.RLock() + defer m.mu.RUnlock() -func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool { - for _, oidcProvider := range oidcProviders { - if oidcProvider.Issuer() == issuer { - return true - } - } - return false + return m.providerHandlers[req.Host+"/"+req.URL.Path] } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go new file mode 100644 index 00000000..c877f3ac --- /dev/null +++ b/internal/oidc/provider/manager/manager_test.go @@ -0,0 +1,123 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package manager + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/sclevine/spec" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/discovery" + "go.pinniped.dev/internal/oidc/provider" +) + +func TestManager(t *testing.T) { + spec.Run(t, "ServeHTTP", func(t *testing.T, when spec.G, it spec.S) { + var r *require.Assertions + var subject *Manager + var nextHandler http.HandlerFunc + var fallbackHandlerWasCalled bool + + newGetRequest := func(url string) *http.Request { + return httptest.NewRequest(http.MethodGet, url, nil) + } + + requireDiscoveryRequestToBeHandled := func(issuer, requestURLSuffix string) { + recorder := httptest.NewRecorder() + + subject.ServeHTTP(recorder, newGetRequest(issuer+oidc.WellKnownEndpointPath+requestURLSuffix)) + + r.Equal(http.StatusOK, recorder.Code) + responseBody, err := ioutil.ReadAll(recorder.Body) + r.NoError(err) + parsedDiscoveryResult := discovery.Metadata{} + err = json.Unmarshal(responseBody, &parsedDiscoveryResult) + r.NoError(err) + + r.Equal(issuer, parsedDiscoveryResult.Issuer) + } + + it.Before(func() { + r = require.New(t) + nextHandler = func(http.ResponseWriter, *http.Request) { + fallbackHandlerWasCalled = true + } + subject = NewManager(nextHandler) + }) + + when("given no providers", func() { + it("sends all requests to the nextHandler", func() { + r.False(fallbackHandlerWasCalled) + subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("/anything")) + r.True(fallbackHandlerWasCalled) + }) + }) + + when("given some valid providers", func() { + issuer1 := "https://example.com/some/path" + issuer2 := "https://example.com/some/path/more/deeply/nested/path" // note that this is a sub-path of the other issuer url + + it.Before(func() { + p1, err := provider.NewOIDCProvider(issuer1) + r.NoError(err) + p2, err := provider.NewOIDCProvider(issuer2) + r.NoError(err) + subject.SetProviders(p1, p2) + }) + + it("sends all non-matching host requests to the nextHandler", func() { + r.False(fallbackHandlerWasCalled) + url := strings.ReplaceAll(issuer1+oidc.WellKnownEndpointPath, "example.com", "wrong-host.com") + subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(url)) + r.True(fallbackHandlerWasCalled) + }) + + it("sends all non-matching path requests to the nextHandler", func() { + r.False(fallbackHandlerWasCalled) + subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("https://example.com/path-does-not-match-any-provider")) + r.True(fallbackHandlerWasCalled) + }) + + it("sends requests which match the issuer prefix but do not match any of that provider's known paths to the nextHandler", func() { + r.False(fallbackHandlerWasCalled) + subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(issuer1+"/unhandled-sub-path")) + r.True(fallbackHandlerWasCalled) + }) + + it("routes matching requests to the appropriate provider", func() { + requireDiscoveryRequestToBeHandled(issuer1, "") + requireDiscoveryRequestToBeHandled(issuer2, "") + requireDiscoveryRequestToBeHandled(issuer2, "?some=query") + r.False(fallbackHandlerWasCalled) + }) + }) + + when("given the same valid providers in reverse order", func() { + issuer1 := "https://example.com/some/path" + issuer2 := "https://example.com/some/path/more/deeply/nested/path" + + it.Before(func() { + p1, err := provider.NewOIDCProvider(issuer1) + r.NoError(err) + p2, err := provider.NewOIDCProvider(issuer2) + r.NoError(err) + subject.SetProviders(p2, p1) + }) + + it("still routes matching requests to the appropriate provider", func() { + requireDiscoveryRequestToBeHandled(issuer1, "") + requireDiscoveryRequestToBeHandled(issuer2, "") + requireDiscoveryRequestToBeHandled(issuer2, "?some=query") + r.False(fallbackHandlerWasCalled) + }) + }) + }) +}