Several small refactors related to OIDC providers
This commit is contained in:
parent
da00fc708f
commit
8b7d96f42c
@ -24,7 +24,7 @@ import (
|
||||
"go.pinniped.dev/internal/controller/supervisorconfig"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/downward"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/oidc/provider/manager"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -61,7 +61,7 @@ func waitForSignal() os.Signal {
|
||||
|
||||
func startControllers(
|
||||
ctx context.Context,
|
||||
issuerProvider *provider.Manager,
|
||||
issuerProvider *manager.Manager,
|
||||
pinnipedClient pinnipedclientset.Interface,
|
||||
pinnipedInformers pinnipedinformers.SharedInformerFactory,
|
||||
) {
|
||||
@ -114,7 +114,7 @@ func run(serverInstallationNamespace string) error {
|
||||
pinnipedinformers.WithNamespace(serverInstallationNamespace),
|
||||
)
|
||||
|
||||
oidProvidersManager := provider.NewManager(http.NotFoundHandler())
|
||||
oidProvidersManager := manager.NewManager(http.NotFoundHandler())
|
||||
startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers)
|
||||
|
||||
//nolint: gosec // Intentionally binding to all network interfaces.
|
||||
|
@ -6,8 +6,8 @@ package supervisorconfig
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"go.pinniped.dev/internal/multierror"
|
||||
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
@ -73,10 +73,10 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
||||
|
||||
issuerCounts := make(map[string]int)
|
||||
for _, opc := range all {
|
||||
issuerCounts[opc.Spec.Issuer] = issuerCounts[opc.Spec.Issuer] + 1
|
||||
issuerCounts[opc.Spec.Issuer]++
|
||||
}
|
||||
|
||||
errs := newMultiError()
|
||||
errs := multierror.New()
|
||||
|
||||
oidcProviders := make([]*provider.OIDCProvider, 0)
|
||||
for _, opc := range all {
|
||||
@ -88,36 +88,21 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
||||
configv1alpha1.DuplicateOIDCProviderStatus,
|
||||
"Duplicate issuer",
|
||||
); err != nil {
|
||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
||||
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
issuerURL, err := url.Parse(opc.Spec.Issuer)
|
||||
oidcProvider, err := provider.NewOIDCProvider(opc.Spec.Issuer)
|
||||
if err != nil {
|
||||
if err := c.updateStatus(
|
||||
ctx.Context,
|
||||
opc.Namespace,
|
||||
opc.Name,
|
||||
configv1alpha1.InvalidOIDCProviderStatus,
|
||||
"Invalid issuer URL: "+err.Error(),
|
||||
"Invalid: "+err.Error(),
|
||||
); err != nil {
|
||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
oidcProvider := &provider.OIDCProvider{Issuer: issuerURL}
|
||||
err = oidcProvider.Validate()
|
||||
if err != nil {
|
||||
if err := c.updateStatus(
|
||||
ctx.Context,
|
||||
opc.Namespace,
|
||||
opc.Name,
|
||||
configv1alpha1.InvalidOIDCProviderStatus,
|
||||
"Invalid issuer: "+err.Error(),
|
||||
); err != nil {
|
||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
||||
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
@ -130,14 +115,13 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
||||
configv1alpha1.SuccessOIDCProviderStatus,
|
||||
"Provider successfully created",
|
||||
); err != nil {
|
||||
// errs.add(fmt.Errorf("could not update status: %w", err))
|
||||
return fmt.Errorf("could not update status: %w", err)
|
||||
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
c.providerSetter.SetProviders(oidcProviders...)
|
||||
|
||||
return errs.errOrNil()
|
||||
return errs.ErrOrNil()
|
||||
}
|
||||
|
||||
func (c *oidcProviderConfigWatcherController) updateStatus(
|
||||
@ -171,33 +155,3 @@ func (c *oidcProviderConfigWatcherController) updateStatus(
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
type multiError []error
|
||||
|
||||
func newMultiError() multiError {
|
||||
return make([]error, 0)
|
||||
}
|
||||
|
||||
func (m *multiError) add(err error) {
|
||||
*m = append(*m, err)
|
||||
}
|
||||
|
||||
func (m multiError) len() int {
|
||||
return len(m)
|
||||
}
|
||||
|
||||
func (m multiError) Error() string {
|
||||
sb := strings.Builder{}
|
||||
fmt.Fprintf(&sb, "%d errors:", m.len())
|
||||
for _, err := range m {
|
||||
fmt.Fprintf(&sb, "\n- %s", err.Error())
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m multiError) errOrNil() error {
|
||||
if m.len() > 0 {
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
39
internal/multierror/multierror.go
Normal file
39
internal/multierror/multierror.go
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package multierror
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type multiError []error
|
||||
|
||||
func New() multiError { //nolint:golint // returning a private type for encapsulation purposes
|
||||
return make([]error, 0)
|
||||
}
|
||||
|
||||
func (m *multiError) Add(err error) {
|
||||
*m = append(*m, err)
|
||||
}
|
||||
|
||||
func (m multiError) len() int {
|
||||
return len(m)
|
||||
}
|
||||
|
||||
func (m multiError) Error() string {
|
||||
sb := strings.Builder{}
|
||||
_, _ = fmt.Fprintf(&sb, "%d errors:", m.len())
|
||||
for _, err := range m {
|
||||
_, _ = fmt.Fprintf(&sb, "\n- %s", err.Error())
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m multiError) ErrOrNil() error {
|
||||
if m.len() > 0 {
|
||||
return m
|
||||
}
|
||||
return nil
|
||||
}
|
@ -6,8 +6,9 @@ package discovery
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
)
|
||||
|
||||
// Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the
|
||||
@ -50,9 +51,9 @@ func New(issuerURL string) http.Handler {
|
||||
|
||||
oidcConfig := Metadata{
|
||||
Issuer: issuerURL,
|
||||
AuthorizationEndpoint: fmt.Sprintf("%s/oauth2/v0/auth", issuerURL),
|
||||
TokenEndpoint: fmt.Sprintf("%s/oauth2/v0/token", issuerURL),
|
||||
JWKSURI: fmt.Sprintf("%s/jwks.json", issuerURL),
|
||||
AuthorizationEndpoint: issuerURL + oidc.AuthorizationEndpointPath,
|
||||
TokenEndpoint: issuerURL + oidc.TokenEndpointPath,
|
||||
JWKSURI: issuerURL + oidc.JWKSEndpointPath,
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
SubjectTypesSupported: []string{"public"},
|
||||
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
||||
|
@ -30,13 +30,13 @@ func TestDiscovery(t *testing.T) {
|
||||
name: "happy path",
|
||||
issuer: "https://some-issuer.com/some/path",
|
||||
method: http.MethodGet,
|
||||
path: "/some/path" + oidc.WellKnownURLPath,
|
||||
path: "/some/path" + oidc.WellKnownEndpointPath,
|
||||
wantStatus: http.StatusOK,
|
||||
wantContentType: "application/json",
|
||||
wantBody: &Metadata{
|
||||
Issuer: "https://some-issuer.com/some/path",
|
||||
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/v0/auth",
|
||||
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/v0/token",
|
||||
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/authorize",
|
||||
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/token",
|
||||
JWKSURI: "https://some-issuer.com/some/path/jwks.json",
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
SubjectTypesSupported: []string{"public"},
|
||||
@ -51,7 +51,7 @@ func TestDiscovery(t *testing.T) {
|
||||
name: "bad method",
|
||||
issuer: "https://some-issuer.com",
|
||||
method: http.MethodPost,
|
||||
path: oidc.WellKnownURLPath,
|
||||
path: oidc.WellKnownEndpointPath,
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: map[string]string{
|
||||
"error": "Method not allowed (try GET)",
|
||||
|
@ -5,5 +5,8 @@
|
||||
package oidc
|
||||
|
||||
const (
|
||||
WellKnownURLPath = "/.well-known/openid-configuration"
|
||||
WellKnownEndpointPath = "/.well-known/openid-configuration"
|
||||
AuthorizationEndpointPath = "/oauth2/authorize"
|
||||
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
|
||||
JWKSEndpointPath = "/jwks.json"
|
||||
)
|
||||
|
@ -1,11 +1,10 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package provider
|
||||
package manager
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@ -13,6 +12,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/discovery"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
)
|
||||
|
||||
// Manager can manage multiple active OIDC providers. It acts as a request router for them.
|
||||
@ -24,21 +24,17 @@ type Manager struct {
|
||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||
}
|
||||
|
||||
// New returns an empty Manager.
|
||||
// NewManager returns an empty Manager.
|
||||
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
||||
func NewManager(nextHandler http.Handler) *Manager {
|
||||
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
|
||||
}
|
||||
|
||||
type providerHandler struct {
|
||||
provider *OIDCProvider
|
||||
provider *provider.OIDCProvider
|
||||
discoveryHandler http.Handler
|
||||
}
|
||||
|
||||
func (h *providerHandler) Issuer() *url.URL {
|
||||
return h.provider.Issuer
|
||||
}
|
||||
|
||||
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
|
||||
// as the name of the provider to decide if it is an add or update operation.
|
||||
//
|
||||
@ -47,12 +43,12 @@ func (h *providerHandler) Issuer() *url.URL {
|
||||
//
|
||||
// This method assumes that all of the OIDCProvider arguments have already been validated
|
||||
// by someone else before they are passed to this method.
|
||||
func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
|
||||
func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
// Add all of the incoming providers.
|
||||
for _, incomingProvider := range oidcProviders {
|
||||
issuerString := incomingProvider.Issuer.String()
|
||||
issuerString := incomingProvider.Issuer()
|
||||
c.providerHandlers[issuerString] = &providerHandler{
|
||||
provider: incomingProvider,
|
||||
discoveryHandler: discovery.New(issuerString),
|
||||
@ -70,9 +66,9 @@ func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
|
||||
|
||||
// ServeHTTP implements the http.Handler interface.
|
||||
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
providerHandler := c.findProviderHandlerByIssuerURL(req.Host, req.URL.Path)
|
||||
providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path)
|
||||
if providerHandler != nil {
|
||||
if req.URL.Path == providerHandler.Issuer().Path+oidc.WellKnownURLPath {
|
||||
if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath {
|
||||
providerHandler.discoveryHandler.ServeHTTP(resp, req)
|
||||
return // handled!
|
||||
}
|
||||
@ -94,20 +90,20 @@ func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||
c.nextHandler.ServeHTTP(resp, req)
|
||||
}
|
||||
|
||||
func (c *Manager) findProviderHandlerByIssuerURL(host, path string) *providerHandler {
|
||||
func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler {
|
||||
for _, providerHandler := range c.providerHandlers {
|
||||
pi := providerHandler.Issuer()
|
||||
// TODO do we need to compare scheme? not sure how to get it from the http.Request object
|
||||
if host == pi.Host && strings.HasPrefix(path, pi.Path) { // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
|
||||
// TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
|
||||
if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) {
|
||||
return providerHandler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIssuerInListOfProviders(issuer string, oidcProviders []*OIDCProvider) bool {
|
||||
for _, provider := range oidcProviders {
|
||||
if provider.Issuer.String() == issuer {
|
||||
func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool {
|
||||
for _, oidcProvider := range oidcProviders {
|
||||
if oidcProvider.Issuer() == issuer {
|
||||
return true
|
||||
}
|
||||
}
|
@ -4,6 +4,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@ -12,39 +13,68 @@ import (
|
||||
|
||||
// OIDCProvider represents all of the settings and state for an OIDC provider.
|
||||
type OIDCProvider struct {
|
||||
Issuer *url.URL
|
||||
issuer string
|
||||
issuerHost string
|
||||
issuerPath string
|
||||
}
|
||||
|
||||
// Validate returns an error if there is anything wrong with the provider settings, or
|
||||
// returns nil if there is nothing wrong with the settings.
|
||||
func (p *OIDCProvider) Validate() error {
|
||||
if p.Issuer == nil {
|
||||
return constable.Error(`provider must have an issuer`)
|
||||
func NewOIDCProvider(issuer string) (*OIDCProvider, error) {
|
||||
p := OIDCProvider{issuer: issuer}
|
||||
err := p.validate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) validate() error {
|
||||
if p.issuer == "" {
|
||||
return constable.Error("provider must have an issuer")
|
||||
}
|
||||
|
||||
if p.Issuer.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(p.Issuer.Scheme) {
|
||||
issuerURL, err := url.Parse(p.issuer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not parse issuer as URL: %w", err)
|
||||
}
|
||||
|
||||
if issuerURL.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(issuerURL.Scheme) {
|
||||
return constable.Error(`issuer must have "https" scheme`)
|
||||
}
|
||||
|
||||
if p.Issuer.User != nil {
|
||||
if issuerURL.User != nil {
|
||||
return constable.Error(`issuer must not have username or password`)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(p.Issuer.Path, "/") {
|
||||
if strings.HasSuffix(issuerURL.Path, "/") {
|
||||
return constable.Error(`issuer must not have trailing slash in path`)
|
||||
}
|
||||
|
||||
if p.Issuer.RawQuery != "" {
|
||||
if issuerURL.RawQuery != "" {
|
||||
return constable.Error(`issuer must not have query`)
|
||||
}
|
||||
|
||||
if p.Issuer.Fragment != "" {
|
||||
if issuerURL.Fragment != "" {
|
||||
return constable.Error(`issuer must not have fragment`)
|
||||
}
|
||||
|
||||
p.issuerHost = issuerURL.Host
|
||||
p.issuerPath = issuerURL.Path
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) Issuer() string {
|
||||
return p.issuer
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IssuerHost() string {
|
||||
return p.issuerHost
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) IssuerPath() string {
|
||||
return p.issuerPath
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool {
|
||||
return scheme != "http"
|
||||
}
|
||||
|
@ -4,7 +4,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -13,63 +12,62 @@ import (
|
||||
func TestOIDCProviderValidations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
issuer *url.URL
|
||||
issuer string
|
||||
wantError string
|
||||
}{
|
||||
{
|
||||
name: "provider must have an issuer",
|
||||
issuer: nil,
|
||||
issuer: "",
|
||||
wantError: "provider must have an issuer",
|
||||
},
|
||||
{
|
||||
name: "no scheme",
|
||||
issuer: must(url.Parse("tuna.com")),
|
||||
issuer: "tuna.com",
|
||||
wantError: `issuer must have "https" scheme`,
|
||||
},
|
||||
{
|
||||
name: "bad scheme",
|
||||
issuer: must(url.Parse("ftp://tuna.com")),
|
||||
issuer: "ftp://tuna.com",
|
||||
wantError: `issuer must have "https" scheme`,
|
||||
},
|
||||
{
|
||||
name: "fragment",
|
||||
issuer: must(url.Parse("https://tuna.com/fish#some-frag")),
|
||||
issuer: "https://tuna.com/fish#some-frag",
|
||||
wantError: `issuer must not have fragment`,
|
||||
},
|
||||
{
|
||||
name: "query",
|
||||
issuer: must(url.Parse("https://tuna.com?some=query")),
|
||||
issuer: "https://tuna.com?some=query",
|
||||
wantError: `issuer must not have query`,
|
||||
},
|
||||
{
|
||||
name: "username",
|
||||
issuer: must(url.Parse("https://username@tuna.com")),
|
||||
issuer: "https://username@tuna.com",
|
||||
wantError: `issuer must not have username or password`,
|
||||
},
|
||||
{
|
||||
name: "password",
|
||||
issuer: must(url.Parse("https://username:password@tuna.com")),
|
||||
issuer: "https://username:password@tuna.com",
|
||||
wantError: `issuer must not have username or password`,
|
||||
},
|
||||
{
|
||||
name: "without path",
|
||||
issuer: must(url.Parse("https://tuna.com")),
|
||||
issuer: "https://tuna.com",
|
||||
},
|
||||
{
|
||||
name: "with path",
|
||||
issuer: must(url.Parse("https://tuna.com/fish/marlin")),
|
||||
issuer: "https://tuna.com/fish/marlin",
|
||||
},
|
||||
{
|
||||
name: "trailing slash in path",
|
||||
issuer: must(url.Parse("https://tuna.com/")),
|
||||
issuer: "https://tuna.com/",
|
||||
wantError: `issuer must not have trailing slash in path`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := OIDCProvider{Issuer: tt.issuer}
|
||||
err := p.Validate()
|
||||
_, err := NewOIDCProvider(tt.issuer)
|
||||
if tt.wantError != "" {
|
||||
require.EqualError(t, err, tt.wantError)
|
||||
} else {
|
||||
@ -78,10 +76,3 @@ func TestOIDCProviderValidations(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func must(u *url.URL, err error) *url.URL {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
@ -175,8 +175,8 @@ func requireWellKnownEndpointIsWorking(t *testing.T, issuerName string) {
|
||||
// Check that the response matches our expectations.
|
||||
expectedResultTemplate := here.Doc(`{
|
||||
"issuer": "%s",
|
||||
"authorization_endpoint": "%s/oauth2/v0/auth",
|
||||
"token_endpoint": "%s/oauth2/v0/token",
|
||||
"authorization_endpoint": "%s/oauth2/authorize",
|
||||
"token_endpoint": "%s/oauth2/token",
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
|
||||
"token_endpoint_auth_signing_alg_values_supported": ["RS256"],
|
||||
"jwks_uri": "%s/jwks.json",
|
||||
|
Loading…
Reference in New Issue
Block a user