Several small refactors related to OIDC providers
This commit is contained in:
parent
da00fc708f
commit
8b7d96f42c
@ -24,7 +24,7 @@ import (
|
|||||||
"go.pinniped.dev/internal/controller/supervisorconfig"
|
"go.pinniped.dev/internal/controller/supervisorconfig"
|
||||||
"go.pinniped.dev/internal/controllerlib"
|
"go.pinniped.dev/internal/controllerlib"
|
||||||
"go.pinniped.dev/internal/downward"
|
"go.pinniped.dev/internal/downward"
|
||||||
"go.pinniped.dev/internal/oidc/provider"
|
"go.pinniped.dev/internal/oidc/provider/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -61,7 +61,7 @@ func waitForSignal() os.Signal {
|
|||||||
|
|
||||||
func startControllers(
|
func startControllers(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
issuerProvider *provider.Manager,
|
issuerProvider *manager.Manager,
|
||||||
pinnipedClient pinnipedclientset.Interface,
|
pinnipedClient pinnipedclientset.Interface,
|
||||||
pinnipedInformers pinnipedinformers.SharedInformerFactory,
|
pinnipedInformers pinnipedinformers.SharedInformerFactory,
|
||||||
) {
|
) {
|
||||||
@ -114,7 +114,7 @@ func run(serverInstallationNamespace string) error {
|
|||||||
pinnipedinformers.WithNamespace(serverInstallationNamespace),
|
pinnipedinformers.WithNamespace(serverInstallationNamespace),
|
||||||
)
|
)
|
||||||
|
|
||||||
oidProvidersManager := provider.NewManager(http.NotFoundHandler())
|
oidProvidersManager := manager.NewManager(http.NotFoundHandler())
|
||||||
startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers)
|
startControllers(ctx, oidProvidersManager, pinnipedClient, pinnipedInformers)
|
||||||
|
|
||||||
//nolint: gosec // Intentionally binding to all network interfaces.
|
//nolint: gosec // Intentionally binding to all network interfaces.
|
||||||
|
@ -6,8 +6,8 @@ package supervisorconfig
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"go.pinniped.dev/internal/multierror"
|
||||||
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
"k8s.io/apimachinery/pkg/labels"
|
"k8s.io/apimachinery/pkg/labels"
|
||||||
@ -73,10 +73,10 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
|||||||
|
|
||||||
issuerCounts := make(map[string]int)
|
issuerCounts := make(map[string]int)
|
||||||
for _, opc := range all {
|
for _, opc := range all {
|
||||||
issuerCounts[opc.Spec.Issuer] = issuerCounts[opc.Spec.Issuer] + 1
|
issuerCounts[opc.Spec.Issuer]++
|
||||||
}
|
}
|
||||||
|
|
||||||
errs := newMultiError()
|
errs := multierror.New()
|
||||||
|
|
||||||
oidcProviders := make([]*provider.OIDCProvider, 0)
|
oidcProviders := make([]*provider.OIDCProvider, 0)
|
||||||
for _, opc := range all {
|
for _, opc := range all {
|
||||||
@ -88,36 +88,21 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
|||||||
configv1alpha1.DuplicateOIDCProviderStatus,
|
configv1alpha1.DuplicateOIDCProviderStatus,
|
||||||
"Duplicate issuer",
|
"Duplicate issuer",
|
||||||
); err != nil {
|
); err != nil {
|
||||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
issuerURL, err := url.Parse(opc.Spec.Issuer)
|
oidcProvider, err := provider.NewOIDCProvider(opc.Spec.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err := c.updateStatus(
|
if err := c.updateStatus(
|
||||||
ctx.Context,
|
ctx.Context,
|
||||||
opc.Namespace,
|
opc.Namespace,
|
||||||
opc.Name,
|
opc.Name,
|
||||||
configv1alpha1.InvalidOIDCProviderStatus,
|
configv1alpha1.InvalidOIDCProviderStatus,
|
||||||
"Invalid issuer URL: "+err.Error(),
|
"Invalid: "+err.Error(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
oidcProvider := &provider.OIDCProvider{Issuer: issuerURL}
|
|
||||||
err = oidcProvider.Validate()
|
|
||||||
if err != nil {
|
|
||||||
if err := c.updateStatus(
|
|
||||||
ctx.Context,
|
|
||||||
opc.Namespace,
|
|
||||||
opc.Name,
|
|
||||||
configv1alpha1.InvalidOIDCProviderStatus,
|
|
||||||
"Invalid issuer: "+err.Error(),
|
|
||||||
); err != nil {
|
|
||||||
errs.add(fmt.Errorf("could not update status: %w", err))
|
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -130,14 +115,13 @@ func (c *oidcProviderConfigWatcherController) Sync(ctx controllerlib.Context) er
|
|||||||
configv1alpha1.SuccessOIDCProviderStatus,
|
configv1alpha1.SuccessOIDCProviderStatus,
|
||||||
"Provider successfully created",
|
"Provider successfully created",
|
||||||
); err != nil {
|
); err != nil {
|
||||||
// errs.add(fmt.Errorf("could not update status: %w", err))
|
errs.Add(fmt.Errorf("could not update status: %w", err))
|
||||||
return fmt.Errorf("could not update status: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.providerSetter.SetProviders(oidcProviders...)
|
c.providerSetter.SetProviders(oidcProviders...)
|
||||||
|
|
||||||
return errs.errOrNil()
|
return errs.ErrOrNil()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *oidcProviderConfigWatcherController) updateStatus(
|
func (c *oidcProviderConfigWatcherController) updateStatus(
|
||||||
@ -171,33 +155,3 @@ func (c *oidcProviderConfigWatcherController) updateStatus(
|
|||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
type multiError []error
|
|
||||||
|
|
||||||
func newMultiError() multiError {
|
|
||||||
return make([]error, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *multiError) add(err error) {
|
|
||||||
*m = append(*m, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiError) len() int {
|
|
||||||
return len(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiError) Error() string {
|
|
||||||
sb := strings.Builder{}
|
|
||||||
fmt.Fprintf(&sb, "%d errors:", m.len())
|
|
||||||
for _, err := range m {
|
|
||||||
fmt.Fprintf(&sb, "\n- %s", err.Error())
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiError) errOrNil() error {
|
|
||||||
if m.len() > 0 {
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
39
internal/multierror/multierror.go
Normal file
39
internal/multierror/multierror.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package multierror
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type multiError []error
|
||||||
|
|
||||||
|
func New() multiError { //nolint:golint // returning a private type for encapsulation purposes
|
||||||
|
return make([]error, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *multiError) Add(err error) {
|
||||||
|
*m = append(*m, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m multiError) len() int {
|
||||||
|
return len(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m multiError) Error() string {
|
||||||
|
sb := strings.Builder{}
|
||||||
|
_, _ = fmt.Fprintf(&sb, "%d errors:", m.len())
|
||||||
|
for _, err := range m {
|
||||||
|
_, _ = fmt.Fprintf(&sb, "\n- %s", err.Error())
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m multiError) ErrOrNil() error {
|
||||||
|
if m.len() > 0 {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -6,8 +6,9 @@ package discovery
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"go.pinniped.dev/internal/oidc"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the
|
// Metadata holds all fields (that we care about) from the OpenID Provider Metadata section in the
|
||||||
@ -50,9 +51,9 @@ func New(issuerURL string) http.Handler {
|
|||||||
|
|
||||||
oidcConfig := Metadata{
|
oidcConfig := Metadata{
|
||||||
Issuer: issuerURL,
|
Issuer: issuerURL,
|
||||||
AuthorizationEndpoint: fmt.Sprintf("%s/oauth2/v0/auth", issuerURL),
|
AuthorizationEndpoint: issuerURL + oidc.AuthorizationEndpointPath,
|
||||||
TokenEndpoint: fmt.Sprintf("%s/oauth2/v0/token", issuerURL),
|
TokenEndpoint: issuerURL + oidc.TokenEndpointPath,
|
||||||
JWKSURI: fmt.Sprintf("%s/jwks.json", issuerURL),
|
JWKSURI: issuerURL + oidc.JWKSEndpointPath,
|
||||||
ResponseTypesSupported: []string{"code"},
|
ResponseTypesSupported: []string{"code"},
|
||||||
SubjectTypesSupported: []string{"public"},
|
SubjectTypesSupported: []string{"public"},
|
||||||
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
IDTokenSigningAlgValuesSupported: []string{"RS256"},
|
||||||
|
@ -30,13 +30,13 @@ func TestDiscovery(t *testing.T) {
|
|||||||
name: "happy path",
|
name: "happy path",
|
||||||
issuer: "https://some-issuer.com/some/path",
|
issuer: "https://some-issuer.com/some/path",
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: "/some/path" + oidc.WellKnownURLPath,
|
path: "/some/path" + oidc.WellKnownEndpointPath,
|
||||||
wantStatus: http.StatusOK,
|
wantStatus: http.StatusOK,
|
||||||
wantContentType: "application/json",
|
wantContentType: "application/json",
|
||||||
wantBody: &Metadata{
|
wantBody: &Metadata{
|
||||||
Issuer: "https://some-issuer.com/some/path",
|
Issuer: "https://some-issuer.com/some/path",
|
||||||
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/v0/auth",
|
AuthorizationEndpoint: "https://some-issuer.com/some/path/oauth2/authorize",
|
||||||
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/v0/token",
|
TokenEndpoint: "https://some-issuer.com/some/path/oauth2/token",
|
||||||
JWKSURI: "https://some-issuer.com/some/path/jwks.json",
|
JWKSURI: "https://some-issuer.com/some/path/jwks.json",
|
||||||
ResponseTypesSupported: []string{"code"},
|
ResponseTypesSupported: []string{"code"},
|
||||||
SubjectTypesSupported: []string{"public"},
|
SubjectTypesSupported: []string{"public"},
|
||||||
@ -51,7 +51,7 @@ func TestDiscovery(t *testing.T) {
|
|||||||
name: "bad method",
|
name: "bad method",
|
||||||
issuer: "https://some-issuer.com",
|
issuer: "https://some-issuer.com",
|
||||||
method: http.MethodPost,
|
method: http.MethodPost,
|
||||||
path: oidc.WellKnownURLPath,
|
path: oidc.WellKnownEndpointPath,
|
||||||
wantStatus: http.StatusMethodNotAllowed,
|
wantStatus: http.StatusMethodNotAllowed,
|
||||||
wantBody: map[string]string{
|
wantBody: map[string]string{
|
||||||
"error": "Method not allowed (try GET)",
|
"error": "Method not allowed (try GET)",
|
||||||
|
@ -5,5 +5,8 @@
|
|||||||
package oidc
|
package oidc
|
||||||
|
|
||||||
const (
|
const (
|
||||||
WellKnownURLPath = "/.well-known/openid-configuration"
|
WellKnownEndpointPath = "/.well-known/openid-configuration"
|
||||||
|
AuthorizationEndpointPath = "/oauth2/authorize"
|
||||||
|
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
|
||||||
|
JWKSEndpointPath = "/jwks.json"
|
||||||
)
|
)
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
package provider
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -13,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"go.pinniped.dev/internal/oidc"
|
"go.pinniped.dev/internal/oidc"
|
||||||
"go.pinniped.dev/internal/oidc/discovery"
|
"go.pinniped.dev/internal/oidc/discovery"
|
||||||
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager can manage multiple active OIDC providers. It acts as a request router for them.
|
// Manager can manage multiple active OIDC providers. It acts as a request router for them.
|
||||||
@ -24,21 +24,17 @@ type Manager struct {
|
|||||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns an empty Manager.
|
// NewManager returns an empty Manager.
|
||||||
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
||||||
func NewManager(nextHandler http.Handler) *Manager {
|
func NewManager(nextHandler http.Handler) *Manager {
|
||||||
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
|
return &Manager{providerHandlers: make(map[string]*providerHandler), nextHandler: nextHandler}
|
||||||
}
|
}
|
||||||
|
|
||||||
type providerHandler struct {
|
type providerHandler struct {
|
||||||
provider *OIDCProvider
|
provider *provider.OIDCProvider
|
||||||
discoveryHandler http.Handler
|
discoveryHandler http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *providerHandler) Issuer() *url.URL {
|
|
||||||
return h.provider.Issuer
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
|
// SetProviders adds or updates all the given providerHandlers using each provider's issuer string
|
||||||
// as the name of the provider to decide if it is an add or update operation.
|
// as the name of the provider to decide if it is an add or update operation.
|
||||||
//
|
//
|
||||||
@ -47,12 +43,12 @@ func (h *providerHandler) Issuer() *url.URL {
|
|||||||
//
|
//
|
||||||
// This method assumes that all of the OIDCProvider arguments have already been validated
|
// This method assumes that all of the OIDCProvider arguments have already been validated
|
||||||
// by someone else before they are passed to this method.
|
// by someone else before they are passed to this method.
|
||||||
func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
|
func (c *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
// Add all of the incoming providers.
|
// Add all of the incoming providers.
|
||||||
for _, incomingProvider := range oidcProviders {
|
for _, incomingProvider := range oidcProviders {
|
||||||
issuerString := incomingProvider.Issuer.String()
|
issuerString := incomingProvider.Issuer()
|
||||||
c.providerHandlers[issuerString] = &providerHandler{
|
c.providerHandlers[issuerString] = &providerHandler{
|
||||||
provider: incomingProvider,
|
provider: incomingProvider,
|
||||||
discoveryHandler: discovery.New(issuerString),
|
discoveryHandler: discovery.New(issuerString),
|
||||||
@ -70,9 +66,9 @@ func (c *Manager) SetProviders(oidcProviders ...*OIDCProvider) {
|
|||||||
|
|
||||||
// ServeHTTP implements the http.Handler interface.
|
// ServeHTTP implements the http.Handler interface.
|
||||||
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
||||||
providerHandler := c.findProviderHandlerByIssuerURL(req.Host, req.URL.Path)
|
providerHandler := c.findProviderHandlerByIssuer(req.Host, req.URL.Path)
|
||||||
if providerHandler != nil {
|
if providerHandler != nil {
|
||||||
if req.URL.Path == providerHandler.Issuer().Path+oidc.WellKnownURLPath {
|
if req.URL.Path == providerHandler.provider.IssuerPath()+oidc.WellKnownEndpointPath {
|
||||||
providerHandler.discoveryHandler.ServeHTTP(resp, req)
|
providerHandler.discoveryHandler.ServeHTTP(resp, req)
|
||||||
return // handled!
|
return // handled!
|
||||||
}
|
}
|
||||||
@ -94,20 +90,20 @@ func (c *Manager) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
|
|||||||
c.nextHandler.ServeHTTP(resp, req)
|
c.nextHandler.ServeHTTP(resp, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Manager) findProviderHandlerByIssuerURL(host, path string) *providerHandler {
|
func (c *Manager) findProviderHandlerByIssuer(host, path string) *providerHandler {
|
||||||
for _, providerHandler := range c.providerHandlers {
|
for _, providerHandler := range c.providerHandlers {
|
||||||
pi := providerHandler.Issuer()
|
|
||||||
// TODO do we need to compare scheme? not sure how to get it from the http.Request object
|
// TODO do we need to compare scheme? not sure how to get it from the http.Request object
|
||||||
if host == pi.Host && strings.HasPrefix(path, pi.Path) { // TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
|
// TODO probably need better logic here? also maybe needs some of the logic from inside ServeMux
|
||||||
|
if host == providerHandler.provider.IssuerHost() && strings.HasPrefix(path, providerHandler.provider.IssuerPath()) {
|
||||||
return providerHandler
|
return providerHandler
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findIssuerInListOfProviders(issuer string, oidcProviders []*OIDCProvider) bool {
|
func findIssuerInListOfProviders(issuer string, oidcProviders []*provider.OIDCProvider) bool {
|
||||||
for _, provider := range oidcProviders {
|
for _, oidcProvider := range oidcProviders {
|
||||||
if provider.Issuer.String() == issuer {
|
if oidcProvider.Issuer() == issuer {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -4,6 +4,7 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@ -12,39 +13,68 @@ import (
|
|||||||
|
|
||||||
// OIDCProvider represents all of the settings and state for an OIDC provider.
|
// OIDCProvider represents all of the settings and state for an OIDC provider.
|
||||||
type OIDCProvider struct {
|
type OIDCProvider struct {
|
||||||
Issuer *url.URL
|
issuer string
|
||||||
|
issuerHost string
|
||||||
|
issuerPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate returns an error if there is anything wrong with the provider settings, or
|
func NewOIDCProvider(issuer string) (*OIDCProvider, error) {
|
||||||
// returns nil if there is nothing wrong with the settings.
|
p := OIDCProvider{issuer: issuer}
|
||||||
func (p *OIDCProvider) Validate() error {
|
err := p.validate()
|
||||||
if p.Issuer == nil {
|
if err != nil {
|
||||||
return constable.Error(`provider must have an issuer`)
|
return nil, err
|
||||||
|
}
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) validate() error {
|
||||||
|
if p.issuer == "" {
|
||||||
|
return constable.Error("provider must have an issuer")
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Issuer.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(p.Issuer.Scheme) {
|
issuerURL, err := url.Parse(p.issuer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not parse issuer as URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if issuerURL.Scheme != "https" && p.removeMeAfterWeNoLongerNeedHTTPIssuerSupport(issuerURL.Scheme) {
|
||||||
return constable.Error(`issuer must have "https" scheme`)
|
return constable.Error(`issuer must have "https" scheme`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Issuer.User != nil {
|
if issuerURL.User != nil {
|
||||||
return constable.Error(`issuer must not have username or password`)
|
return constable.Error(`issuer must not have username or password`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasSuffix(p.Issuer.Path, "/") {
|
if strings.HasSuffix(issuerURL.Path, "/") {
|
||||||
return constable.Error(`issuer must not have trailing slash in path`)
|
return constable.Error(`issuer must not have trailing slash in path`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Issuer.RawQuery != "" {
|
if issuerURL.RawQuery != "" {
|
||||||
return constable.Error(`issuer must not have query`)
|
return constable.Error(`issuer must not have query`)
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Issuer.Fragment != "" {
|
if issuerURL.Fragment != "" {
|
||||||
return constable.Error(`issuer must not have fragment`)
|
return constable.Error(`issuer must not have fragment`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.issuerHost = issuerURL.Host
|
||||||
|
p.issuerPath = issuerURL.Path
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) Issuer() string {
|
||||||
|
return p.issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) IssuerHost() string {
|
||||||
|
return p.issuerHost
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) IssuerPath() string {
|
||||||
|
return p.issuerPath
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool {
|
func (p *OIDCProvider) removeMeAfterWeNoLongerNeedHTTPIssuerSupport(scheme string) bool {
|
||||||
return scheme != "http"
|
return scheme != "http"
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
package provider
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -13,63 +12,62 @@ import (
|
|||||||
func TestOIDCProviderValidations(t *testing.T) {
|
func TestOIDCProviderValidations(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
issuer *url.URL
|
issuer string
|
||||||
wantError string
|
wantError string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "provider must have an issuer",
|
name: "provider must have an issuer",
|
||||||
issuer: nil,
|
issuer: "",
|
||||||
wantError: "provider must have an issuer",
|
wantError: "provider must have an issuer",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no scheme",
|
name: "no scheme",
|
||||||
issuer: must(url.Parse("tuna.com")),
|
issuer: "tuna.com",
|
||||||
wantError: `issuer must have "https" scheme`,
|
wantError: `issuer must have "https" scheme`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "bad scheme",
|
name: "bad scheme",
|
||||||
issuer: must(url.Parse("ftp://tuna.com")),
|
issuer: "ftp://tuna.com",
|
||||||
wantError: `issuer must have "https" scheme`,
|
wantError: `issuer must have "https" scheme`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fragment",
|
name: "fragment",
|
||||||
issuer: must(url.Parse("https://tuna.com/fish#some-frag")),
|
issuer: "https://tuna.com/fish#some-frag",
|
||||||
wantError: `issuer must not have fragment`,
|
wantError: `issuer must not have fragment`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "query",
|
name: "query",
|
||||||
issuer: must(url.Parse("https://tuna.com?some=query")),
|
issuer: "https://tuna.com?some=query",
|
||||||
wantError: `issuer must not have query`,
|
wantError: `issuer must not have query`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "username",
|
name: "username",
|
||||||
issuer: must(url.Parse("https://username@tuna.com")),
|
issuer: "https://username@tuna.com",
|
||||||
wantError: `issuer must not have username or password`,
|
wantError: `issuer must not have username or password`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "password",
|
name: "password",
|
||||||
issuer: must(url.Parse("https://username:password@tuna.com")),
|
issuer: "https://username:password@tuna.com",
|
||||||
wantError: `issuer must not have username or password`,
|
wantError: `issuer must not have username or password`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "without path",
|
name: "without path",
|
||||||
issuer: must(url.Parse("https://tuna.com")),
|
issuer: "https://tuna.com",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "with path",
|
name: "with path",
|
||||||
issuer: must(url.Parse("https://tuna.com/fish/marlin")),
|
issuer: "https://tuna.com/fish/marlin",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "trailing slash in path",
|
name: "trailing slash in path",
|
||||||
issuer: must(url.Parse("https://tuna.com/")),
|
issuer: "https://tuna.com/",
|
||||||
wantError: `issuer must not have trailing slash in path`,
|
wantError: `issuer must not have trailing slash in path`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p := OIDCProvider{Issuer: tt.issuer}
|
_, err := NewOIDCProvider(tt.issuer)
|
||||||
err := p.Validate()
|
|
||||||
if tt.wantError != "" {
|
if tt.wantError != "" {
|
||||||
require.EqualError(t, err, tt.wantError)
|
require.EqualError(t, err, tt.wantError)
|
||||||
} else {
|
} else {
|
||||||
@ -78,10 +76,3 @@ func TestOIDCProviderValidations(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func must(u *url.URL, err error) *url.URL {
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
@ -175,8 +175,8 @@ func requireWellKnownEndpointIsWorking(t *testing.T, issuerName string) {
|
|||||||
// Check that the response matches our expectations.
|
// Check that the response matches our expectations.
|
||||||
expectedResultTemplate := here.Doc(`{
|
expectedResultTemplate := here.Doc(`{
|
||||||
"issuer": "%s",
|
"issuer": "%s",
|
||||||
"authorization_endpoint": "%s/oauth2/v0/auth",
|
"authorization_endpoint": "%s/oauth2/authorize",
|
||||||
"token_endpoint": "%s/oauth2/v0/token",
|
"token_endpoint": "%s/oauth2/token",
|
||||||
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
|
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
|
||||||
"token_endpoint_auth_signing_alg_values_supported": ["RS256"],
|
"token_endpoint_auth_signing_alg_values_supported": ["RS256"],
|
||||||
"jwks_uri": "%s/jwks.json",
|
"jwks_uri": "%s/jwks.json",
|
||||||
|
Loading…
Reference in New Issue
Block a user