Several small refactors related to OIDC providers

This commit is contained in:
Ryan Richard 2020-10-08 11:28:21 -07:00
parent da00fc708f
commit 8b7d96f42c
10 changed files with 134 additions and 120 deletions

View File

@ -24,7 +24,7 @@ import (
"go.pinniped.dev/internal/controller/supervisorconfig" "go.pinniped.dev/internal/controller/supervisorconfig"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/downward" "go.pinniped.dev/internal/downward"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/manager"
) )
const ( const (
@ -61,7 +61,7 @@ func waitForSignal() os.Signal {
func startControllers( func startControllers(
ctx context.Context, ctx context.Context,
issuerProvider *provider.Manager, issuerProvider *manager.Manager,
pinnipedClient pinnipedclientset.Interface, pinnipedClient pinnipedclientset.Interface,
pinnipedInformers pinnipedinformers.SharedInformerFactory, pinnipedInformers pinnipedinformers.SharedInformerFactory,
) { ) {
@ -114,7 +114,7 @@ func run(serverInstallationNamespace string) error {
pinnipedinformers.WithNamespace(serverInstallationNamespace), pinnipedinformers.WithNamespace(serverInstallationNamespace),
) )
oidProvidersManager := provider.NewManager(http.NotFoundHandler()) oidProvidersManager := manager.NewManager(http.NotFoundHandler())
startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers) startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers)
//nolint: gosec // Intentionally binding to all network interfaces. //nolint: gosec // Intentionally binding to all network interfaces.

View File

@ -6,8 +6,8 @@ package supervisorconfig
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"strings" "go.pinniped.dev/internal/multierror"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
@ -73,10 +73,10 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
issuerCounts := make(map[string]int) issuerCounts := make(map[string]int)
for _, opc := range all { for _, opc := range all {
issuerCounts[opc.Spec.Issuer] = issuerCounts[opc.Spec.Issuer] + 1 issuerCounts[opc.Spec.Issuer]++
} }
errs := newMultiError() errs := multierror.New()
oidcProviders := make([]*provider.OIDCProvider, 0) oidcProviders := make([]*provider.OIDCProvider, 0)
for _, opc := range all { for _, opc := range all {
@ -88,36 +88,21 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
configv1alpha1.DuplicateOIDCProviderStatus, configv1alpha1.DuplicateOIDCProviderStatus,
"Duplicate issuer", "Duplicate issuer",
); err != nil { ); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err)) errs.Add(fmt.Errorf("could not update status: %w", err))
} }
continue continue
} }
issuerURL, err := url.Parse(opc.Spec.Issuer) oidcProvider, err := provider.NewOIDCProvider(opc.Spec.Issuer)
if err != nil { if err != nil {
if err := c.updateStatus( if err := c.updateStatus(
ctx.Context, ctx.Context,
opc.Namespace, opc.Namespace,
opc.Name, opc.Name,
configv1alpha1.InvalidOIDCProviderStatus, configv1alpha1.InvalidOIDCProviderStatus,
"Invalid issuer URL: "+err.Error(), "Invalid: "+err.Error(),
); err != nil { ); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err)) errs.Add(fmt.Errorf("could not update status: %w", err))
}
continue
}
oidcProvider := &provider.OIDCProvider{Issuer: issuerURL}
err = oidcProvider.Validate()
if err != nil {
if err := c.updateStatus(
ctx.Context,
opc.Namespace,
opc.Name,
configv1alpha1.InvalidOIDCProviderStatus,
"Invalid issuer: "+err.Error(),
); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err))
} }
continue continue
} }
@ -130,14 +115,13 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
configv1alpha1.SuccessOIDCProviderStatus, configv1alpha1.SuccessOIDCProviderStatus,
"Provider successfully created", "Provider successfully created",
); err != nil { ); err != nil {
// errs.add(fmt.Errorf("could not update status: %w", err)) errs.Add(fmt.Errorf("could not update status: %w", err))
return fmt.Errorf("could not update status: %w", err)
} }
} }
c.providerSetter.SetProviders(oidcProviders...) c.providerSetter.SetProviders(oidcProviders...)
return errs.errOrNil() return errs.ErrOrNil()
} }
func (c *oidcProviderConfigWatcherController) updateStatus( func (c *oidcProviderConfigWatcherController) updateStatus(
@ -171,33 +155,3 @@ func (c *oidcProviderConfigWatcherController) updateStatus(
return err return err
}) })
} }
type multiError []error
func newMultiError() multiError {
return make([]error, 0)
}
func (m *multiError) add(err error) {
*m = append(*m, err)
}
func (m multiError) len() int {
return len(m)
}
func (m multiError) Error() string {
sb := strings.Builder{}
fmt.Fprintf(&sb, "%d errors:", m.len())
for _, err := range m {
fmt.Fprintf(&sb, "\n- %s", err.Error())
}
return sb.String()
}
func (m multiError) errOrNil() error {
if m.len() > 0 {
return m
}
return nil
}

View File

@ -0,0 +1,39 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package multierror
import (
"fmt"
"strings"
)
type multiError []error
func New() multiError { //nolint:golint // returning a private type for encapsulation purposes
return make([]error, 0)
}
func (m *multiError) Add(err error) {
*m = append(*m, err)
}
func (m multiError) len() int {
return len(m)
}
func (m multiError) Error() string {
sb := strings.Builder{}
_, _ = fmt.Fprintf(&sb, "%d errors:", m.len())
for _, err := range m {
_, _ = fmt.Fprintf(&sb, "\n- %s", err.Error())
}
return sb.String()
}
func (m multiError) ErrOrNil() error {
if m.len() > 0 {
return m
}
return nil
}

View File

@ -6,8 +6,9 @@ package discovery
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"go.pinniped.dev/internal/oidc"
) )
// Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the // Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the
@ -50,9 +51,9 @@ func New(issuerURL string) http.Handler {
oidcConfig := Metadata{ oidcConfig := Metadata{
Issuer: issuerURL, Issuer: issuerURL,
AuthorizationEndpoint: fmt.Sprintf("%s/oauth2/v0/auth", issuerURL), AuthorizationEndpoint: issuerURL + oidc.AuthorizationEndpointPath,
TokenEndpoint: fmt.Sprintf("%s/oauth2/v0/token", issuerURL), TokenEndpoint: issuerURL + oidc.TokenEndpointPath,
JWKSURI: fmt.Sprintf("%s/jwks.json", issuerURL), JWKSURI: issuerURL + oidc.JWKSEndpointPath,
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
SubjectTypesSupported: []string{"public"}, SubjectTypesSupported: []string{"public"},
IDTokenSigningAlgValuesSupported: []string{"RS256"}, IDTokenSigningAlgValuesSupported: []string{"RS256"},

View File

@ -30,13 +30,13 @@ func TestDiscovery(t *testing.T) {
name: "happy path", name: "happy path",
issuer: "https://some-issuer.com/some/path", issuer: "https://some-issuer.com/some/path",
method: http.MethodGet, method: http.MethodGet,
path: "/some/path" + oidc.WellKnownURLPath, path: "/some/path" + oidc.WellKnownEndpointPath,
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantContentType: "application/json", wantContentType: "application/json",
wantBody: &Metadata{ wantBody: &Metadata{
Issuer: "https://some-issuer.com/some/path", Issuer: "https://some-issuer.com/some/path",
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/v0/auth", AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/authorize",
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/v0/token", TokenEndpoint: "https://some-issuer.com/some/path/oauth2/token",
JWKSURI: "https://some-issuer.com/some/path/jwks.json", JWKSURI: "https://some-issuer.com/some/path/jwks.json",
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
SubjectTypesSupported: []string{"public"}, SubjectTypesSupported: []string{"public"},
@ -51,7 +51,7 @@ func TestDiscovery(t *testing.T) {
name: "bad method", name: "bad method",
issuer: "https://some-issuer.com", issuer: "https://some-issuer.com",
method: http.MethodPost, method: http.MethodPost,
path: oidc.WellKnownURLPath, path: oidc.WellKnownEndpointPath,
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
wantBody: map[string]string{ wantBody: map[string]string{
"error": "Method not allowed (try GET)", "error": "Method not allowed (try GET)",

View File

@ -5,5 +5,8 @@
package oidc package oidc
const ( const (
WellKnownURLPath = "/.well-known/openid-configuration" WellKnownEndpointPath = "/.well-known/openid-configuration"
AuthorizationEndpointPath = "/oauth2/authorize"
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
JWKSEndpointPath = "/jwks.json"
) )

View File

@ -1,11 +1,10 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package provider package manager
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
@ -13,6 +12,7 @@ import (
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/provider"
) )
// Manager can manage multiple active OIDC providers. It acts as a request router for them. // Manager can manage multiple active OIDC providers. It acts as a request router for them.
@ -24,21 +24,17 @@ type Manager struct {
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
} }
// New returns an empty Manager. // NewManager returns an empty Manager.
// nextHandler will be invoked for any requests that could not be handled by this manager's providers. // nextHandler will be invoked for any requests that could not be handled by this manager's providers.
func NewManager(nextHandler http.Handler) *Manager { func NewManager(nextHandler http.Handler) *Manager {
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler} return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
} }
type providerHandler struct { type providerHandler struct {
provider *OIDCProvider provider *provider.OIDCProvider
discoveryHandler http.Handler discoveryHandler http.Handler
} }
func (h *providerHandler) Issuer() *url.URL {
return h.provider.Issuer
}
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string // SetProviders adds or updates all the given providerHandlers using each provider's issuer string
// as the name of the provider to decide if it is an add or update operation. // as the name of the provider to decide if it is an add or update operation.
// //
@ -47,12 +43,12 @@ func (h *providerHandler) Issuer() *url.URL {
// //
// This method assumes that all of the OIDCProvider arguments have already been validated // This method assumes that all of the OIDCProvider arguments have already been validated
// by someone else before they are passed to this method. // by someone else before they are passed to this method.
func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) { func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
// Add all of the incoming providers. // Add all of the incoming providers.
for _, incomingProvider := range oidcProviders { for _, incomingProvider := range oidcProviders {
issuerString := incomingProvider.Issuer.String() issuerString := incomingProvider.Issuer()
c.providerHandlers[issuerString] = &providerHandler{ c.providerHandlers[issuerString] = &providerHandler{
provider: incomingProvider, provider: incomingProvider,
discoveryHandler: discovery.New(issuerString), discoveryHandler: discovery.New(issuerString),
@ -70,9 +66,9 @@ func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
// ServeHTTP implements the http.Handler interface. // ServeHTTP implements the http.Handler interface.
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) { func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
providerHandler := c.findProviderHandlerByIssuerURL(req.Host, req.URL.Path) providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path)
if providerHandler != nil { if providerHandler != nil {
if req.URL.Path == providerHandler.Issuer().Path+oidc.WellKnownURLPath { if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath {
providerHandler.discoveryHandler.ServeHTTP(resp, req) providerHandler.discoveryHandler.ServeHTTP(resp, req)
return // handled! return // handled!
} }
@ -94,20 +90,20 @@ func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
c.nextHandler.ServeHTTP(resp, req) c.nextHandler.ServeHTTP(resp, req)
} }
func (c *Manager) findProviderHandlerByIssuerURL(host, path string) *providerHandler { func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler {
for _, providerHandler := range c.providerHandlers { for _, providerHandler := range c.providerHandlers {
pi := providerHandler.Issuer()
// TODO do we need to compare scheme? not sure how to get it from the http.Request object // TODO do we need to compare scheme? not sure how to get it from the http.Request object
if host == pi.Host && strings.HasPrefix(path, pi.Path) { // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) {
return providerHandler return providerHandler
} }
} }
return nil return nil
} }
func findIssuerInListOfProviders(issuer string, oidcProviders []*OIDCProvider) bool { func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool {
for _, provider := range oidcProviders { for _, oidcProvider := range oidcProviders {
if provider.Issuer.String() == issuer { if oidcProvider.Issuer() == issuer {
return true return true
} }
} }

View File

@ -4,6 +4,7 @@
package provider package provider
import ( import (
"fmt"
"net/url" "net/url"
"strings" "strings"
@ -12,39 +13,68 @@ import (
// OIDCProvider represents all of the settings and state for an OIDC provider. // OIDCProvider represents all of the settings and state for an OIDC provider.
type OIDCProvider struct { type OIDCProvider struct {
Issuer *url.URL issuer string
issuerHost string
issuerPath string
} }
// Validate returns an error if there is anything wrong with the provider settings, or func NewOIDCProvider(issuer string) (*OIDCProvider, error) {
// returns nil if there is nothing wrong with the settings. p := OIDCProvider{issuer: issuer}
func (p *OIDCProvider) Validate() error { err := p.validate()
if p.Issuer == nil { if err != nil {
return constable.Error(`provider must have an issuer`) return nil, err
}
return &p, nil
} }
if p.Issuer.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(p.Issuer.Scheme) { func (p *OIDCProvider) validate() error {
if p.issuer == "" {
return constable.Error("provider must have an issuer")
}
issuerURL, err := url.Parse(p.issuer)
if err != nil {
return fmt.Errorf("could not parse issuer as URL: %w", err)
}
if issuerURL.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(issuerURL.Scheme) {
return constable.Error(`issuer must have "https" scheme`) return constable.Error(`issuer must have "https" scheme`)
} }
if p.Issuer.User != nil { if issuerURL.User != nil {
return constable.Error(`issuer must not have username or password`) return constable.Error(`issuer must not have username or password`)
} }
if strings.HasSuffix(p.Issuer.Path, "/") { if strings.HasSuffix(issuerURL.Path, "/") {
return constable.Error(`issuer must not have trailing slash in path`) return constable.Error(`issuer must not have trailing slash in path`)
} }
if p.Issuer.RawQuery != "" { if issuerURL.RawQuery != "" {
return constable.Error(`issuer must not have query`) return constable.Error(`issuer must not have query`)
} }
if p.Issuer.Fragment != "" { if issuerURL.Fragment != "" {
return constable.Error(`issuer must not have fragment`) return constable.Error(`issuer must not have fragment`)
} }
p.issuerHost = issuerURL.Host
p.issuerPath = issuerURL.Path
return nil return nil
} }
func (p *OIDCProvider) Issuer() string {
return p.issuer
}
func (p *OIDCProvider) IssuerHost() string {
return p.issuerHost
}
func (p *OIDCProvider) IssuerPath() string {
return p.issuerPath
}
func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool { func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool {
return scheme != "http" return scheme != "http"
} }

View File

@ -4,7 +4,6 @@
package provider package provider
import ( import (
"net/url"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -13,63 +12,62 @@ import (
func TestOIDCProviderValidations(t *testing.T) { func TestOIDCProviderValidations(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
issuer *url.URL issuer string
wantError string wantError string
}{ }{
{ {
name: "provider must have an issuer", name: "provider must have an issuer",
issuer: nil, issuer: "",
wantError: "provider must have an issuer", wantError: "provider must have an issuer",
}, },
{ {
name: "no scheme", name: "no scheme",
issuer: must(url.Parse("tuna.com")), issuer: "tuna.com",
wantError: `issuer must have "https" scheme`, wantError: `issuer must have "https" scheme`,
}, },
{ {
name: "bad scheme", name: "bad scheme",
issuer: must(url.Parse("ftp://tuna.com")), issuer: "ftp://tuna.com",
wantError: `issuer must have "https" scheme`, wantError: `issuer must have "https" scheme`,
}, },
{ {
name: "fragment", name: "fragment",
issuer: must(url.Parse("https://tuna.com/fish#some-frag")), issuer: "https://tuna.com/fish#some-frag",
wantError: `issuer must not have fragment`, wantError: `issuer must not have fragment`,
}, },
{ {
name: "query", name: "query",
issuer: must(url.Parse("https://tuna.com?some=query")), issuer: "https://tuna.com?some=query",
wantError: `issuer must not have query`, wantError: `issuer must not have query`,
}, },
{ {
name: "username", name: "username",
issuer: must(url.Parse("https://username@tuna.com")), issuer: "https://username@tuna.com",
wantError: `issuer must not have username or password`, wantError: `issuer must not have username or password`,
}, },
{ {
name: "password", name: "password",
issuer: must(url.Parse("https://username:password@tuna.com")), issuer: "https://username:password@tuna.com",
wantError: `issuer must not have username or password`, wantError: `issuer must not have username or password`,
}, },
{ {
name: "without path", name: "without path",
issuer: must(url.Parse("https://tuna.com")), issuer: "https://tuna.com",
}, },
{ {
name: "with path", name: "with path",
issuer: must(url.Parse("https://tuna.com/fish/marlin")), issuer: "https://tuna.com/fish/marlin",
}, },
{ {
name: "trailing slash in path", name: "trailing slash in path",
issuer: must(url.Parse("https://tuna.com/")), issuer: "https://tuna.com/",
wantError: `issuer must not have trailing slash in path`, wantError: `issuer must not have trailing slash in path`,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p := OIDCProvider{Issuer: tt.issuer} _, err := NewOIDCProvider(tt.issuer)
err := p.Validate()
if tt.wantError != "" { if tt.wantError != "" {
require.EqualError(t, err, tt.wantError) require.EqualError(t, err, tt.wantError)
} else { } else {
@ -78,10 +76,3 @@ func TestOIDCProviderValidations(t *testing.T) {
}) })
} }
} }
func must(u *url.URL, err error) *url.URL {
if err != nil {
panic(err)
}
return u
}

View File

@ -175,8 +175,8 @@ func requireWellKnownEndpointIsWorking(t *testing.T, issuerName string) {
// Check that the response matches our expectations. // Check that the response matches our expectations.
expectedResultTemplate := here.Doc(`{ expectedResultTemplate := here.Doc(`{
"issuer": "%s", "issuer": "%s",
"authorization_endpoint": "%s/oauth2/v0/auth", "authorization_endpoint": "%s/oauth2/authorize",
"token_endpoint": "%s/oauth2/v0/token", "token_endpoint": "%s/oauth2/token",
"token_endpoint_auth_methods_supported": ["client_secret_basic"], "token_endpoint_auth_methods_supported": ["client_secret_basic"],
"token_endpoint_auth_signing_alg_values_supported": ["RS256"], "token_endpoint_auth_signing_alg_values_supported": ["RS256"],
"jwks_uri": "%s/jwks.json", "jwks_uri": "%s/jwks.json",