Several small refactors related to OIDC providers

This commit is contained in:
Ryan Richard 2020-10-08 11:28:21 -07:00
parent da00fc708f
commit 8b7d96f42c
10 changed files with 134 additions and 120 deletions

View File

@ -24,7 +24,7 @@ import (
"go.pinniped.dev/internal/controller/supervisorconfig"
"go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/downward"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/manager"
)
const (
@ -61,7 +61,7 @@ func waitForSignal() os.Signal {
func startControllers(
ctx context.Context,
issuerProvider *provider.Manager,
issuerProvider *manager.Manager,
pinnipedClient pinnipedclientset.Interface,
pinnipedInformers pinnipedinformers.SharedInformerFactory,
) {
@ -114,7 +114,7 @@ func run(serverInstallationNamespace string) error {
pinnipedinformers.WithNamespace(serverInstallationNamespace),
)
oidProvidersManager := provider.NewManager(http.NotFoundHandler())
oidProvidersManager := manager.NewManager(http.NotFoundHandler())
startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers)
//nolint: gosec // Intentionally binding to all network interfaces.

View File

@ -6,8 +6,8 @@ package supervisorconfig
import (
"context"
"fmt"
"net/url"
"strings"
"go.pinniped.dev/internal/multierror"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
@ -73,10 +73,10 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
issuerCounts := make(map[string]int)
for _, opc := range all {
issuerCounts[opc.Spec.Issuer] = issuerCounts[opc.Spec.Issuer] + 1
issuerCounts[opc.Spec.Issuer]++
}
errs := newMultiError()
errs := multierror.New()
oidcProviders := make([]*provider.OIDCProvider, 0)
for _, opc := range all {
@ -88,36 +88,21 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
configv1alpha1.DuplicateOIDCProviderStatus,
"Duplicate issuer",
); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err))
errs.Add(fmt.Errorf("could not update status: %w", err))
}
continue
}
issuerURL, err := url.Parse(opc.Spec.Issuer)
oidcProvider, err := provider.NewOIDCProvider(opc.Spec.Issuer)
if err != nil {
if err := c.updateStatus(
ctx.Context,
opc.Namespace,
opc.Name,
configv1alpha1.InvalidOIDCProviderStatus,
"Invalid issuer URL: "+err.Error(),
"Invalid: "+err.Error(),
); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err))
}
continue
}
oidcProvider := &provider.OIDCProvider{Issuer: issuerURL}
err = oidcProvider.Validate()
if err != nil {
if err := c.updateStatus(
ctx.Context,
opc.Namespace,
opc.Name,
configv1alpha1.InvalidOIDCProviderStatus,
"Invalid issuer: "+err.Error(),
); err != nil {
errs.add(fmt.Errorf("could not update status: %w", err))
errs.Add(fmt.Errorf("could not update status: %w", err))
}
continue
}
@ -130,14 +115,13 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
configv1alpha1.SuccessOIDCProviderStatus,
"Provider successfully created",
); err != nil {
// errs.add(fmt.Errorf("could not update status: %w", err))
return fmt.Errorf("could not update status: %w", err)
errs.Add(fmt.Errorf("could not update status: %w", err))
}
}
c.providerSetter.SetProviders(oidcProviders...)
return errs.errOrNil()
return errs.ErrOrNil()
}
func (c *oidcProviderConfigWatcherController) updateStatus(
@ -171,33 +155,3 @@ func (c *oidcProviderConfigWatcherController) updateStatus(
return err
})
}
type multiError []error
func newMultiError() multiError {
return make([]error, 0)
}
func (m *multiError) add(err error) {
*m = append(*m, err)
}
func (m multiError) len() int {
return len(m)
}
func (m multiError) Error() string {
sb := strings.Builder{}
fmt.Fprintf(&sb, "%d errors:", m.len())
for _, err := range m {
fmt.Fprintf(&sb, "\n- %s", err.Error())
}
return sb.String()
}
func (m multiError) errOrNil() error {
if m.len() > 0 {
return m
}
return nil
}

View File

@ -0,0 +1,39 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package multierror
import (
"fmt"
"strings"
)
type multiError []error
func New() multiError { //nolint:golint // returning a private type for encapsulation purposes
return make([]error, 0)
}
func (m *multiError) Add(err error) {
*m = append(*m, err)
}
func (m multiError) len() int {
return len(m)
}
func (m multiError) Error() string {
sb := strings.Builder{}
_, _ = fmt.Fprintf(&sb, "%d errors:", m.len())
for _, err := range m {
_, _ = fmt.Fprintf(&sb, "\n- %s", err.Error())
}
return sb.String()
}
func (m multiError) ErrOrNil() error {
if m.len() > 0 {
return m
}
return nil
}

View File

@ -6,8 +6,9 @@ package discovery
import (
"encoding/json"
"fmt"
"net/http"
"go.pinniped.dev/internal/oidc"
)
// Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the
@ -50,9 +51,9 @@ func New(issuerURL string) http.Handler {
oidcConfig := Metadata{
Issuer: issuerURL,
AuthorizationEndpoint: fmt.Sprintf("%s/oauth2/v0/auth", issuerURL),
TokenEndpoint: fmt.Sprintf("%s/oauth2/v0/token", issuerURL),
JWKSURI: fmt.Sprintf("%s/jwks.json", issuerURL),
AuthorizationEndpoint: issuerURL + oidc.AuthorizationEndpointPath,
TokenEndpoint: issuerURL + oidc.TokenEndpointPath,
JWKSURI: issuerURL + oidc.JWKSEndpointPath,
ResponseTypesSupported: []string{"code"},
SubjectTypesSupported: []string{"public"},
IDTokenSigningAlgValuesSupported: []string{"RS256"},

View File

@ -30,13 +30,13 @@ func TestDiscovery(t *testing.T) {
name: "happy path",
issuer: "https://some-issuer.com/some/path",
method: http.MethodGet,
path: "/some/path" + oidc.WellKnownURLPath,
path: "/some/path" + oidc.WellKnownEndpointPath,
wantStatus: http.StatusOK,
wantContentType: "application/json",
wantBody: &Metadata{
Issuer: "https://some-issuer.com/some/path",
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/v0/auth",
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/v0/token",
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/authorize",
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/token",
JWKSURI: "https://some-issuer.com/some/path/jwks.json",
ResponseTypesSupported: []string{"code"},
SubjectTypesSupported: []string{"public"},
@ -51,7 +51,7 @@ func TestDiscovery(t *testing.T) {
name: "bad method",
issuer: "https://some-issuer.com",
method: http.MethodPost,
path: oidc.WellKnownURLPath,
path: oidc.WellKnownEndpointPath,
wantStatus: http.StatusMethodNotAllowed,
wantBody: map[string]string{
"error": "Method not allowed (try GET)",

View File

@ -5,5 +5,8 @@
package oidc
const (
WellKnownURLPath = "/.well-known/openid-configuration"
WellKnownEndpointPath = "/.well-known/openid-configuration"
AuthorizationEndpointPath = "/oauth2/authorize"
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
JWKSEndpointPath = "/jwks.json"
)

View File

@ -1,11 +1,10 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package provider
package manager
import (
"net/http"
"net/url"
"strings"
"sync"
@ -13,6 +12,7 @@ import (
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/provider"
)
// Manager can manage multiple active OIDC providers. It acts as a request router for them.
@ -24,21 +24,17 @@ type Manager struct {
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
}
// New returns an empty Manager.
// NewManager returns an empty Manager.
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
func NewManager(nextHandler http.Handler) *Manager {
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
}
type providerHandler struct {
provider *OIDCProvider
provider *provider.OIDCProvider
discoveryHandler http.Handler
}
func (h *providerHandler) Issuer() *url.URL {
return h.provider.Issuer
}
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
// as the name of the provider to decide if it is an add or update operation.
//
@ -47,12 +43,12 @@ func (h *providerHandler) Issuer() *url.URL {
//
// This method assumes that all of the OIDCProvider arguments have already been validated
// by someone else before they are passed to this method.
func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
c.mu.Lock()
defer c.mu.Unlock()
// Add all of the incoming providers.
for _, incomingProvider := range oidcProviders {
issuerString := incomingProvider.Issuer.String()
issuerString := incomingProvider.Issuer()
c.providerHandlers[issuerString] = &providerHandler{
provider: incomingProvider,
discoveryHandler: discovery.New(issuerString),
@ -70,9 +66,9 @@ func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
// ServeHTTP implements the http.Handler interface.
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
providerHandler := c.findProviderHandlerByIssuerURL(req.Host, req.URL.Path)
providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path)
if providerHandler != nil {
if req.URL.Path == providerHandler.Issuer().Path+oidc.WellKnownURLPath {
if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath {
providerHandler.discoveryHandler.ServeHTTP(resp, req)
return // handled!
}
@ -94,20 +90,20 @@ func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
c.nextHandler.ServeHTTP(resp, req)
}
func (c *Manager) findProviderHandlerByIssuerURL(host, path string) *providerHandler {
func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler {
for _, providerHandler := range c.providerHandlers {
pi := providerHandler.Issuer()
// TODO do we need to compare scheme? not sure how to get it from the http.Request object
if host == pi.Host && strings.HasPrefix(path, pi.Path) { // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
// TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) {
return providerHandler
}
}
return nil
}
func findIssuerInListOfProviders(issuer string, oidcProviders []*OIDCProvider) bool {
for _, provider := range oidcProviders {
if provider.Issuer.String() == issuer {
func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool {
for _, oidcProvider := range oidcProviders {
if oidcProvider.Issuer() == issuer {
return true
}
}

View File

@ -4,6 +4,7 @@
package provider
import (
"fmt"
"net/url"
"strings"
@ -12,39 +13,68 @@ import (
// OIDCProvider represents all of the settings and state for an OIDC provider.
type OIDCProvider struct {
Issuer *url.URL
issuer string
issuerHost string
issuerPath string
}
// Validate returns an error if there is anything wrong with the provider settings, or
// returns nil if there is nothing wrong with the settings.
func (p *OIDCProvider) Validate() error {
if p.Issuer == nil {
return constable.Error(`provider must have an issuer`)
func NewOIDCProvider(issuer string) (*OIDCProvider, error) {
p := OIDCProvider{issuer: issuer}
err := p.validate()
if err != nil {
return nil, err
}
return &p, nil
}
if p.Issuer.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(p.Issuer.Scheme) {
func (p *OIDCProvider) validate() error {
if p.issuer == "" {
return constable.Error("provider must have an issuer")
}
issuerURL, err := url.Parse(p.issuer)
if err != nil {
return fmt.Errorf("could not parse issuer as URL: %w", err)
}
if issuerURL.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(issuerURL.Scheme) {
return constable.Error(`issuer must have "https" scheme`)
}
if p.Issuer.User != nil {
if issuerURL.User != nil {
return constable.Error(`issuer must not have username or password`)
}
if strings.HasSuffix(p.Issuer.Path, "/") {
if strings.HasSuffix(issuerURL.Path, "/") {
return constable.Error(`issuer must not have trailing slash in path`)
}
if p.Issuer.RawQuery != "" {
if issuerURL.RawQuery != "" {
return constable.Error(`issuer must not have query`)
}
if p.Issuer.Fragment != "" {
if issuerURL.Fragment != "" {
return constable.Error(`issuer must not have fragment`)
}
p.issuerHost = issuerURL.Host
p.issuerPath = issuerURL.Path
return nil
}
func (p *OIDCProvider) Issuer() string {
return p.issuer
}
func (p *OIDCProvider) IssuerHost() string {
return p.issuerHost
}
func (p *OIDCProvider) IssuerPath() string {
return p.issuerPath
}
func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool {
return scheme != "http"
}

View File

@ -4,7 +4,6 @@
package provider
import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
@ -13,63 +12,62 @@ import (
func TestOIDCProviderValidations(t *testing.T) {
tests := []struct {
name string
issuer *url.URL
issuer string
wantError string
}{
{
name: "provider must have an issuer",
issuer: nil,
issuer: "",
wantError: "provider must have an issuer",
},
{
name: "no scheme",
issuer: must(url.Parse("tuna.com")),
issuer: "tuna.com",
wantError: `issuer must have "https" scheme`,
},
{
name: "bad scheme",
issuer: must(url.Parse("ftp://tuna.com")),
issuer: "ftp://tuna.com",
wantError: `issuer must have "https" scheme`,
},
{
name: "fragment",
issuer: must(url.Parse("https://tuna.com/fish#some-frag")),
issuer: "https://tuna.com/fish#some-frag",
wantError: `issuer must not have fragment`,
},
{
name: "query",
issuer: must(url.Parse("https://tuna.com?some=query")),
issuer: "https://tuna.com?some=query",
wantError: `issuer must not have query`,
},
{
name: "username",
issuer: must(url.Parse("https://username@tuna.com")),
issuer: "https://username@tuna.com",
wantError: `issuer must not have username or password`,
},
{
name: "password",
issuer: must(url.Parse("https://username:password@tuna.com")),
issuer: "https://username:password@tuna.com",
wantError: `issuer must not have username or password`,
},
{
name: "without path",
issuer: must(url.Parse("https://tuna.com")),
issuer: "https://tuna.com",
},
{
name: "with path",
issuer: must(url.Parse("https://tuna.com/fish/marlin")),
issuer: "https://tuna.com/fish/marlin",
},
{
name: "trailing slash in path",
issuer: must(url.Parse("https://tuna.com/")),
issuer: "https://tuna.com/",
wantError: `issuer must not have trailing slash in path`,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
p := OIDCProvider{Issuer: tt.issuer}
err := p.Validate()
_, err := NewOIDCProvider(tt.issuer)
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
@ -78,10 +76,3 @@ func TestOIDCProviderValidations(t *testing.T) {
})
}
}
func must(u *url.URL, err error) *url.URL {
if err != nil {
panic(err)
}
return u
}

View File

@ -175,8 +175,8 @@ func requireWellKnownEndpointIsWorking(t *testing.T, issuerName string) {
// Check that the response matches our expectations.
expectedResultTemplate := here.Doc(`{
"issuer": "%s",
"authorization_endpoint": "%s/oauth2/v0/auth",
"token_endpoint": "%s/oauth2/v0/token",
"authorization_endpoint": "%s/oauth2/authorize",
"token_endpoint": "%s/oauth2/token",
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
"token_endpoint_auth_signing_alg_values_supported": ["RS256"],
"jwks_uri": "%s/jwks.json",