From 8b7d96f42cb0192e6c105087d9ec682b555014df Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Thu, 8 Oct 2020 11:28:21 -0700 Subject: [PATCH] Several small refactors related to OIDC providers --- cmd/pinniped-supervisor/main.go | 6 +- .../oidcproviderconfig_watcher.go | 66 +++---------------- internal/multierror/multierror.go | 39 +++++++++++ internal/oidc/discovery/discovery.go | 9 +-- internal/oidc/discovery/discovery_test.go | 8 +-- internal/oidc/oidc.go | 5 +- .../oidc/provider/{ => manager}/manager.go | 32 ++++----- internal/oidc/provider/oidcprovider.go | 52 +++++++++++---- internal/oidc/provider/oidcprovider_test.go | 33 ++++------ test/integration/supervisor_discovery_test.go | 4 +- 10 files changed, 134 insertions(+), 120 deletions(-) create mode 100644 internal/multierror/multierror.go rename internal/oidc/provider/{ => manager}/manager.go (78%) diff --git a/cmd/pinniped-supervisor/main.go b/cmd/pinniped-supervisor/main.go index dbbf5cfc..fe04cd7f 100644 --- a/cmd/pinniped-supervisor/main.go +++ b/cmd/pinniped-supervisor/main.go @@ -24,7 +24,7 @@ import ( "go.pinniped.dev/internal/controller/supervisorconfig" "go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/downward" - "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidc/provider/manager" ) const ( @@ -61,7 +61,7 @@ func waitForSignal() os.Signal { func startControllers( ctx context.Context, - issuerProvider *provider.Manager, + issuerProvider *manager.Manager, pinnipedClient pinnipedclientset.Interface, pinnipedInformers pinnipedinformers.SharedInformerFactory, ) { @@ -114,7 +114,7 @@ func run(serverInstallationNamespace string) error { pinnipedinformers.WithNamespace(serverInstallationNamespace), ) - oidProvidersManager := provider.NewManager(http.NotFoundHandler()) + oidProvidersManager := manager.NewManager(http.NotFoundHandler()) startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers) //nolint: gosec // Intentionally binding to all network interfaces. diff --git a/internal/controller/supervisorconfig/oidcproviderconfig_watcher.go b/internal/controller/supervisorconfig/oidcproviderconfig_watcher.go index cd827c94..3f246a85 100644 --- a/internal/controller/supervisorconfig/oidcproviderconfig_watcher.go +++ b/internal/controller/supervisorconfig/oidcproviderconfig_watcher.go @@ -6,8 +6,8 @@ package supervisorconfig import ( "context" "fmt" - "net/url" - "strings" + + "go.pinniped.dev/internal/multierror" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" @@ -73,10 +73,10 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er issuerCounts := make(map[string]int) for _, opc := range all { - issuerCounts[opc.Spec.Issuer] = issuerCounts[opc.Spec.Issuer] + 1 + issuerCounts[opc.Spec.Issuer]++ } - errs := newMultiError() + errs := multierror.New() oidcProviders := make([]*provider.OIDCProvider, 0) for _, opc := range all { @@ -88,36 +88,21 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er configv1alpha1.DuplicateOIDCProviderStatus, "Duplicate issuer", ); err != nil { - errs.add(fmt.Errorf("could not update status: %w", err)) + errs.Add(fmt.Errorf("could not update status: %w", err)) } continue } - issuerURL, err := url.Parse(opc.Spec.Issuer) + oidcProvider, err := provider.NewOIDCProvider(opc.Spec.Issuer) if err != nil { if err := c.updateStatus( ctx.Context, opc.Namespace, opc.Name, configv1alpha1.InvalidOIDCProviderStatus, - "Invalid issuer URL: "+err.Error(), + "Invalid: "+err.Error(), ); err != nil { - errs.add(fmt.Errorf("could not update status: %w", err)) - } - continue - } - - oidcProvider := &provider.OIDCProvider{Issuer: issuerURL} - err = oidcProvider.Validate() - if err != nil { - if err := c.updateStatus( - ctx.Context, - opc.Namespace, - opc.Name, - configv1alpha1.InvalidOIDCProviderStatus, - "Invalid issuer: "+err.Error(), - ); err != nil { - errs.add(fmt.Errorf("could not update status: %w", err)) + errs.Add(fmt.Errorf("could not update status: %w", err)) } continue } @@ -130,14 +115,13 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er configv1alpha1.SuccessOIDCProviderStatus, "Provider successfully created", ); err != nil { - // errs.add(fmt.Errorf("could not update status: %w", err)) - return fmt.Errorf("could not update status: %w", err) + errs.Add(fmt.Errorf("could not update status: %w", err)) } } c.providerSetter.SetProviders(oidcProviders...) - return errs.errOrNil() + return errs.ErrOrNil() } func (c *oidcProviderConfigWatcherController) updateStatus( @@ -171,33 +155,3 @@ func (c *oidcProviderConfigWatcherController) updateStatus( return err }) } - -type multiError []error - -func newMultiError() multiError { - return make([]error, 0) -} - -func (m *multiError) add(err error) { - *m = append(*m, err) -} - -func (m multiError) len() int { - return len(m) -} - -func (m multiError) Error() string { - sb := strings.Builder{} - fmt.Fprintf(&sb, "%d errors:", m.len()) - for _, err := range m { - fmt.Fprintf(&sb, "\n- %s", err.Error()) - } - return sb.String() -} - -func (m multiError) errOrNil() error { - if m.len() > 0 { - return m - } - return nil -} diff --git a/internal/multierror/multierror.go b/internal/multierror/multierror.go new file mode 100644 index 00000000..f9d9f2f8 --- /dev/null +++ b/internal/multierror/multierror.go @@ -0,0 +1,39 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package multierror + +import ( + "fmt" + "strings" +) + +type multiError []error + +func New() multiError { //nolint:golint // returning a private type for encapsulation purposes + return make([]error, 0) +} + +func (m *multiError) Add(err error) { + *m = append(*m, err) +} + +func (m multiError) len() int { + return len(m) +} + +func (m multiError) Error() string { + sb := strings.Builder{} + _, _ = fmt.Fprintf(&sb, "%d errors:", m.len()) + for _, err := range m { + _, _ = fmt.Fprintf(&sb, "\n- %s", err.Error()) + } + return sb.String() +} + +func (m multiError) ErrOrNil() error { + if m.len() > 0 { + return m + } + return nil +} diff --git a/internal/oidc/discovery/discovery.go b/internal/oidc/discovery/discovery.go index 84ca1c77..c919e150 100644 --- a/internal/oidc/discovery/discovery.go +++ b/internal/oidc/discovery/discovery.go @@ -6,8 +6,9 @@ package discovery import ( "encoding/json" - "fmt" "net/http" + + "go.pinniped.dev/internal/oidc" ) // Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the @@ -50,9 +51,9 @@ func New(issuerURL string) http.Handler { oidcConfig := Metadata{ Issuer: issuerURL, - AuthorizationEndpoint: fmt.Sprintf("%s/oauth2/v0/auth", issuerURL), - TokenEndpoint: fmt.Sprintf("%s/oauth2/v0/token", issuerURL), - JWKSURI: fmt.Sprintf("%s/jwks.json", issuerURL), + AuthorizationEndpoint: issuerURL + oidc.AuthorizationEndpointPath, + TokenEndpoint: issuerURL + oidc.TokenEndpointPath, + JWKSURI: issuerURL + oidc.JWKSEndpointPath, ResponseTypesSupported: []string{"code"}, SubjectTypesSupported: []string{"public"}, IDTokenSigningAlgValuesSupported: []string{"RS256"}, diff --git a/internal/oidc/discovery/discovery_test.go b/internal/oidc/discovery/discovery_test.go index 14f2d9b6..e21b3c4a 100644 --- a/internal/oidc/discovery/discovery_test.go +++ b/internal/oidc/discovery/discovery_test.go @@ -30,13 +30,13 @@ func TestDiscovery(t *testing.T) { name: "happy path", issuer: "https://some-issuer.com/some/path", method: http.MethodGet, - path: "/some/path" + oidc.WellKnownURLPath, + path: "/some/path" + oidc.WellKnownEndpointPath, wantStatus: http.StatusOK, wantContentType: "application/json", wantBody: &Metadata{ Issuer: "https://some-issuer.com/some/path", - AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/v0/auth", - TokenEndpoint: "https://some-issuer.com/some/path/oauth2/v0/token", + AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/authorize", + TokenEndpoint: "https://some-issuer.com/some/path/oauth2/token", JWKSURI: "https://some-issuer.com/some/path/jwks.json", ResponseTypesSupported: []string{"code"}, SubjectTypesSupported: []string{"public"}, @@ -51,7 +51,7 @@ func TestDiscovery(t *testing.T) { name: "bad method", issuer: "https://some-issuer.com", method: http.MethodPost, - path: oidc.WellKnownURLPath, + path: oidc.WellKnownEndpointPath, wantStatus: http.StatusMethodNotAllowed, wantBody: map[string]string{ "error": "Method not allowed (try GET)", diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 795c3d38..d78f199c 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -5,5 +5,8 @@ package oidc const ( - WellKnownURLPath = "/.well-known/openid-configuration" + WellKnownEndpointPath = "/.well-known/openid-configuration" + AuthorizationEndpointPath = "/oauth2/authorize" + TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential + JWKSEndpointPath = "/jwks.json" ) diff --git a/internal/oidc/provider/manager.go b/internal/oidc/provider/manager/manager.go similarity index 78% rename from internal/oidc/provider/manager.go rename to internal/oidc/provider/manager/manager.go index a50a7095..0c11b5be 100644 --- a/internal/oidc/provider/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -1,11 +1,10 @@ // Copyright 2020 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package provider +package manager import ( "net/http" - "net/url" "strings" "sync" @@ -13,6 +12,7 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/discovery" + "go.pinniped.dev/internal/oidc/provider" ) // Manager can manage multiple active OIDC providers. It acts as a request router for them. @@ -24,21 +24,17 @@ type Manager struct { nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request } -// New returns an empty Manager. +// NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. func NewManager(nextHandler http.Handler) *Manager { return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler} } type providerHandler struct { - provider *OIDCProvider + provider *provider.OIDCProvider discoveryHandler http.Handler } -func (h *providerHandler) Issuer() *url.URL { - return h.provider.Issuer -} - // SetProviders adds or updates all the given providerHandlers using each provider's issuer string // as the name of the provider to decide if it is an add or update operation. // @@ -47,12 +43,12 @@ func (h *providerHandler) Issuer() *url.URL { // // This method assumes that all of the OIDCProvider arguments have already been validated // by someone else before they are passed to this method. -func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) { +func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { c.mu.Lock() defer c.mu.Unlock() // Add all of the incoming providers. for _, incomingProvider := range oidcProviders { - issuerString := incomingProvider.Issuer.String() + issuerString := incomingProvider.Issuer() c.providerHandlers[issuerString] = &providerHandler{ provider: incomingProvider, discoveryHandler: discovery.New(issuerString), @@ -70,9 +66,9 @@ func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) { // ServeHTTP implements the http.Handler interface. func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { - providerHandler := c.findProviderHandlerByIssuerURL(req.Host, req.URL.Path) + providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path) if providerHandler != nil { - if req.URL.Path == providerHandler.Issuer().Path+oidc.WellKnownURLPath { + if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath { providerHandler.discoveryHandler.ServeHTTP(resp, req) return // handled! } @@ -94,20 +90,20 @@ func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { c.nextHandler.ServeHTTP(resp, req) } -func (c *Manager) findProviderHandlerByIssuerURL(host, path string) *providerHandler { +func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler { for _, providerHandler := range c.providerHandlers { - pi := providerHandler.Issuer() // TODO do we need to compare scheme? not sure how to get it from the http.Request object - if host == pi.Host && strings.HasPrefix(path, pi.Path) { // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux + // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux + if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) { return providerHandler } } return nil } -func findIssuerInListOfProviders(issuer string, oidcProviders []*OIDCProvider) bool { - for _, provider := range oidcProviders { - if provider.Issuer.String() == issuer { +func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool { + for _, oidcProvider := range oidcProviders { + if oidcProvider.Issuer() == issuer { return true } } diff --git a/internal/oidc/provider/oidcprovider.go b/internal/oidc/provider/oidcprovider.go index c7c24de4..cc427167 100644 --- a/internal/oidc/provider/oidcprovider.go +++ b/internal/oidc/provider/oidcprovider.go @@ -4,6 +4,7 @@ package provider import ( + "fmt" "net/url" "strings" @@ -12,39 +13,68 @@ import ( // OIDCProvider represents all of the settings and state for an OIDC provider. type OIDCProvider struct { - Issuer *url.URL + issuer string + issuerHost string + issuerPath string } -// Validate returns an error if there is anything wrong with the provider settings, or -// returns nil if there is nothing wrong with the settings. -func (p *OIDCProvider) Validate() error { - if p.Issuer == nil { - return constable.Error(`provider must have an issuer`) +func NewOIDCProvider(issuer string) (*OIDCProvider, error) { + p := OIDCProvider{issuer: issuer} + err := p.validate() + if err != nil { + return nil, err + } + return &p, nil +} + +func (p *OIDCProvider) validate() error { + if p.issuer == "" { + return constable.Error("provider must have an issuer") } - if p.Issuer.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(p.Issuer.Scheme) { + issuerURL, err := url.Parse(p.issuer) + if err != nil { + return fmt.Errorf("could not parse issuer as URL: %w", err) + } + + if issuerURL.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(issuerURL.Scheme) { return constable.Error(`issuer must have "https" scheme`) } - if p.Issuer.User != nil { + if issuerURL.User != nil { return constable.Error(`issuer must not have username or password`) } - if strings.HasSuffix(p.Issuer.Path, "/") { + if strings.HasSuffix(issuerURL.Path, "/") { return constable.Error(`issuer must not have trailing slash in path`) } - if p.Issuer.RawQuery != "" { + if issuerURL.RawQuery != "" { return constable.Error(`issuer must not have query`) } - if p.Issuer.Fragment != "" { + if issuerURL.Fragment != "" { return constable.Error(`issuer must not have fragment`) } + p.issuerHost = issuerURL.Host + p.issuerPath = issuerURL.Path + return nil } +func (p *OIDCProvider) Issuer() string { + return p.issuer +} + +func (p *OIDCProvider) IssuerHost() string { + return p.issuerHost +} + +func (p *OIDCProvider) IssuerPath() string { + return p.issuerPath +} + func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool { return scheme != "http" } diff --git a/internal/oidc/provider/oidcprovider_test.go b/internal/oidc/provider/oidcprovider_test.go index 2da0e3a8..81204e28 100644 --- a/internal/oidc/provider/oidcprovider_test.go +++ b/internal/oidc/provider/oidcprovider_test.go @@ -4,7 +4,6 @@ package provider import ( - "net/url" "testing" "github.com/stretchr/testify/require" @@ -13,63 +12,62 @@ import ( func TestOIDCProviderValidations(t *testing.T) { tests := []struct { name string - issuer *url.URL + issuer string wantError string }{ { name: "provider must have an issuer", - issuer: nil, + issuer: "", wantError: "provider must have an issuer", }, { name: "no scheme", - issuer: must(url.Parse("tuna.com")), + issuer: "tuna.com", wantError: `issuer must have "https" scheme`, }, { name: "bad scheme", - issuer: must(url.Parse("ftp://tuna.com")), + issuer: "ftp://tuna.com", wantError: `issuer must have "https" scheme`, }, { name: "fragment", - issuer: must(url.Parse("https://tuna.com/fish#some-frag")), + issuer: "https://tuna.com/fish#some-frag", wantError: `issuer must not have fragment`, }, { name: "query", - issuer: must(url.Parse("https://tuna.com?some=query")), + issuer: "https://tuna.com?some=query", wantError: `issuer must not have query`, }, { name: "username", - issuer: must(url.Parse("https://username@tuna.com")), + issuer: "https://username@tuna.com", wantError: `issuer must not have username or password`, }, { name: "password", - issuer: must(url.Parse("https://username:password@tuna.com")), + issuer: "https://username:password@tuna.com", wantError: `issuer must not have username or password`, }, { name: "without path", - issuer: must(url.Parse("https://tuna.com")), + issuer: "https://tuna.com", }, { name: "with path", - issuer: must(url.Parse("https://tuna.com/fish/marlin")), + issuer: "https://tuna.com/fish/marlin", }, { name: "trailing slash in path", - issuer: must(url.Parse("https://tuna.com/")), + issuer: "https://tuna.com/", wantError: `issuer must not have trailing slash in path`, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - p := OIDCProvider{Issuer: tt.issuer} - err := p.Validate() + _, err := NewOIDCProvider(tt.issuer) if tt.wantError != "" { require.EqualError(t, err, tt.wantError) } else { @@ -78,10 +76,3 @@ func TestOIDCProviderValidations(t *testing.T) { }) } } - -func must(u *url.URL, err error) *url.URL { - if err != nil { - panic(err) - } - return u -} diff --git a/test/integration/supervisor_discovery_test.go b/test/integration/supervisor_discovery_test.go index 94f3cd9d..77dcaebf 100644 --- a/test/integration/supervisor_discovery_test.go +++ b/test/integration/supervisor_discovery_test.go @@ -175,8 +175,8 @@ func requireWellKnownEndpointIsWorking(t *testing.T, issuerName string) { // Check that the response matches our expectations. expectedResultTemplate := here.Doc(`{ "issuer": "%s", - "authorization_endpoint": "%s/oauth2/v0/auth", - "token_endpoint": "%s/oauth2/v0/token", + "authorization_endpoint": "%s/oauth2/authorize", + "token_endpoint": "%s/oauth2/token", "token_endpoint_auth_methods_supported": ["client_secret_basic"], "token_endpoint_auth_signing_alg_values_supported": ["RS256"], "jwks_uri": "%s/jwks.json",