Refactor provider.Manager
- And also handle when an issuer's path is a subpath of another issuer Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
8b7d96f42c
commit
05141592f8
@ -5,7 +5,6 @@ package manager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"k8s.io/klog/v2"
|
"k8s.io/klog/v2"
|
||||||
@ -20,19 +19,15 @@ import (
|
|||||||
// It is thread-safe.
|
// It is thread-safe.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
providerHandlers map[string]*providerHandler // map of issuer name to providerHandler
|
providers []*provider.OIDCProvider
|
||||||
|
providerHandlers map[string]http.Handler // map of all routes for all providers
|
||||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager returns an empty Manager.
|
// NewManager returns an empty Manager.
|
||||||
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
||||||
func NewManager(nextHandler http.Handler) *Manager {
|
func NewManager(nextHandler http.Handler) *Manager {
|
||||||
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
|
return &Manager{providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler}
|
||||||
}
|
|
||||||
|
|
||||||
type providerHandler struct {
|
|
||||||
provider *provider.OIDCProvider
|
|
||||||
discoveryHandler http.Handler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
|
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
|
||||||
@ -43,69 +38,40 @@ type providerHandler struct {
|
|||||||
//
|
//
|
||||||
// This method assumes that all of the OIDCProvider arguments have already been validated
|
// This method assumes that all of the OIDCProvider arguments have already been validated
|
||||||
// by someone else before they are passed to this method.
|
// by someone else before they are passed to this method.
|
||||||
func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||||
c.mu.Lock()
|
m.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
// Add all of the incoming providers.
|
|
||||||
|
m.providers = oidcProviders
|
||||||
|
m.providerHandlers = make(map[string]http.Handler)
|
||||||
|
|
||||||
for _, incomingProvider := range oidcProviders {
|
for _, incomingProvider := range oidcProviders {
|
||||||
issuerString := incomingProvider.Issuer()
|
m.providerHandlers[incomingProvider.IssuerHost()+"/"+incomingProvider.IssuerPath()+oidc.WellKnownEndpointPath] = discovery.New(incomingProvider.Issuer())
|
||||||
c.providerHandlers[issuerString] = &providerHandler{
|
klog.InfoS("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
|
||||||
provider: incomingProvider,
|
|
||||||
discoveryHandler: discovery.New(issuerString),
|
|
||||||
}
|
|
||||||
klog.InfoS("oidc provider manager added or updated issuer", "issuer", issuerString)
|
|
||||||
}
|
|
||||||
// Remove any providers that we previously handled but no longer exist.
|
|
||||||
for issuerKey := range c.providerHandlers {
|
|
||||||
if !findIssuerInListOfProviders(issuerKey, oidcProviders) {
|
|
||||||
delete(c.providerHandlers, issuerKey)
|
|
||||||
klog.InfoS("oidc provider manager removed issuer", "issuer", issuerKey)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP implements the http.Handler interface.
|
// ServeHTTP implements the http.Handler interface.
|
||||||
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
func (m *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path)
|
requestHandler := m.findHandler(req)
|
||||||
if providerHandler != nil {
|
|
||||||
if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath {
|
|
||||||
providerHandler.discoveryHandler.ServeHTTP(resp, req)
|
|
||||||
return // handled!
|
|
||||||
}
|
|
||||||
klog.InfoS(
|
klog.InfoS(
|
||||||
"oidc provider manager found issuer but could not handle request",
|
"oidc provider manager examining request",
|
||||||
"method", req.Method,
|
|
||||||
"host", req.Host,
|
|
||||||
"path", req.URL.Path,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
klog.InfoS(
|
|
||||||
"oidc provider manager could not find issuer to handle request",
|
|
||||||
"method", req.Method,
|
"method", req.Method,
|
||||||
"host", req.Host,
|
"host", req.Host,
|
||||||
"path", req.URL.Path,
|
"path", req.URL.Path,
|
||||||
|
"foundMatchingIssuer", requestHandler != nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if requestHandler == nil {
|
||||||
|
requestHandler = m.nextHandler // couldn't find an issuer to handle the request
|
||||||
}
|
}
|
||||||
// Didn't know how to handle this request, so send it along the chain for further processing.
|
requestHandler.ServeHTTP(resp, req)
|
||||||
c.nextHandler.ServeHTTP(resp, req)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler {
|
func (m *Manager) findHandler(req *http.Request) http.Handler {
|
||||||
for _, providerHandler := range c.providerHandlers {
|
m.mu.RLock()
|
||||||
// TODO do we need to compare scheme? not sure how to get it from the http.Request object
|
defer m.mu.RUnlock()
|
||||||
// TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
|
|
||||||
if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) {
|
|
||||||
return providerHandler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool {
|
return m.providerHandlers[req.Host+"/"+req.URL.Path]
|
||||||
for _, oidcProvider := range oidcProviders {
|
|
||||||
if oidcProvider.Issuer() == issuer {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
123
internal/oidc/provider/manager/manager_test.go
Normal file
123
internal/oidc/provider/manager/manager_test.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sclevine/spec"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"go.pinniped.dev/internal/oidc"
|
||||||
|
"go.pinniped.dev/internal/oidc/discovery"
|
||||||
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager(t *testing.T) {
|
||||||
|
spec.Run(t, "ServeHTTP", func(t *testing.T, when spec.G, it spec.S) {
|
||||||
|
var r *require.Assertions
|
||||||
|
var subject *Manager
|
||||||
|
var nextHandler http.HandlerFunc
|
||||||
|
var fallbackHandlerWasCalled bool
|
||||||
|
|
||||||
|
newGetRequest := func(url string) *http.Request {
|
||||||
|
return httptest.NewRequest(http.MethodGet, url, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
requireDiscoveryRequestToBeHandled := func(issuer, requestURLSuffix string) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
subject.ServeHTTP(recorder, newGetRequest(issuer+oidc.WellKnownEndpointPath+requestURLSuffix))
|
||||||
|
|
||||||
|
r.Equal(http.StatusOK, recorder.Code)
|
||||||
|
responseBody, err := ioutil.ReadAll(recorder.Body)
|
||||||
|
r.NoError(err)
|
||||||
|
parsedDiscoveryResult := discovery.Metadata{}
|
||||||
|
err = json.Unmarshal(responseBody, &parsedDiscoveryResult)
|
||||||
|
r.NoError(err)
|
||||||
|
|
||||||
|
r.Equal(issuer, parsedDiscoveryResult.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
it.Before(func() {
|
||||||
|
r = require.New(t)
|
||||||
|
nextHandler = func(http.ResponseWriter, *http.Request) {
|
||||||
|
fallbackHandlerWasCalled = true
|
||||||
|
}
|
||||||
|
subject = NewManager(nextHandler)
|
||||||
|
})
|
||||||
|
|
||||||
|
when("given no providers", func() {
|
||||||
|
it("sends all requests to the nextHandler", func() {
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("/anything"))
|
||||||
|
r.True(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
when("given some valid providers", func() {
|
||||||
|
issuer1 := "https://example.com/some/path"
|
||||||
|
issuer2 := "https://example.com/some/path/more/deeply/nested/path" // note that this is a sub-path of the other issuer url
|
||||||
|
|
||||||
|
it.Before(func() {
|
||||||
|
p1, err := provider.NewOIDCProvider(issuer1)
|
||||||
|
r.NoError(err)
|
||||||
|
p2, err := provider.NewOIDCProvider(issuer2)
|
||||||
|
r.NoError(err)
|
||||||
|
subject.SetProviders(p1, p2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("sends all non-matching host requests to the nextHandler", func() {
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
url := strings.ReplaceAll(issuer1+oidc.WellKnownEndpointPath, "example.com", "wrong-host.com")
|
||||||
|
subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(url))
|
||||||
|
r.True(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("sends all non-matching path requests to the nextHandler", func() {
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
subject.ServeHTTP(httptest.NewRecorder(), newGetRequest("https://example.com/path-does-not-match-any-provider"))
|
||||||
|
r.True(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("sends requests which match the issuer prefix but do not match any of that provider's known paths to the nextHandler", func() {
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
subject.ServeHTTP(httptest.NewRecorder(), newGetRequest(issuer1+"/unhandled-sub-path"))
|
||||||
|
r.True(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("routes matching requests to the appropriate provider", func() {
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer1, "")
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer2, "")
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer2, "?some=query")
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
when("given the same valid providers in reverse order", func() {
|
||||||
|
issuer1 := "https://example.com/some/path"
|
||||||
|
issuer2 := "https://example.com/some/path/more/deeply/nested/path"
|
||||||
|
|
||||||
|
it.Before(func() {
|
||||||
|
p1, err := provider.NewOIDCProvider(issuer1)
|
||||||
|
r.NoError(err)
|
||||||
|
p2, err := provider.NewOIDCProvider(issuer2)
|
||||||
|
r.NoError(err)
|
||||||
|
subject.SetProviders(p2, p1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it("still routes matching requests to the appropriate provider", func() {
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer1, "")
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer2, "")
|
||||||
|
requireDiscoveryRequestToBeHandled(issuer2, "?some=query")
|
||||||
|
r.False(fallbackHandlerWasCalled)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user