diff --git a/cmd/pinniped-supervisor/main.go b/cmd/pinniped-supervisor/main.go index d2bfc7f5..31f5dff8 100644 --- a/cmd/pinniped-supervisor/main.go +++ b/cmd/pinniped-supervisor/main.go @@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error { dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider() // OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux. - oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider) + oidProvidersManager := manager.NewManager( + healthMux, + dynamicJWKSProvider, + dynamicUpstreamIDPProvider, + kubeClient.CoreV1().Secrets(serverInstallationNamespace), + ) startControllers( ctx, diff --git a/cmd/pinniped/cmd/login_oidc.go b/cmd/pinniped/cmd/login_oidc.go index 1677c00f..c8d00662 100644 --- a/cmd/pinniped/cmd/login_oidc.go +++ b/cmd/pinniped/cmd/login_oidc.go @@ -20,6 +20,7 @@ import ( "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/filesession" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) //nolint: gochecknoinits @@ -27,7 +28,7 @@ func init() { loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login)) } -func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error)) *cobra.Command { +func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error)) *cobra.Command { var ( cmd = cobra.Command{ Args: cobra.NoArgs, diff --git a/cmd/pinniped/cmd/login_oidc_test.go b/cmd/pinniped/cmd/login_oidc_test.go index 3a61934d..37cfac4e 100644 --- a/cmd/pinniped/cmd/login_oidc_test.go +++ b/cmd/pinniped/cmd/login_oidc_test.go @@ -13,6 +13,7 @@ import ( "go.pinniped.dev/internal/here" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestLoginOIDCCommand(t *testing.T) { @@ -92,12 +93,12 @@ func TestLoginOIDCCommand(t *testing.T) { gotClientID string gotOptions []oidcclient.Option ) - cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error) { + cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error) { gotIssuer = issuer gotClientID = clientID gotOptions = opts - return &oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + return &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time1), }, diff --git a/go.mod b/go.mod index 4277bd9b..84237fd7 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/gorilla/securecookie v1.1.1 github.com/ory/fosite v0.35.1 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 + github.com/pkg/errors v0.9.1 github.com/sclevine/agouti v3.0.0+incompatible github.com/sclevine/spec v1.4.0 github.com/spf13/cobra v1.0.0 diff --git a/hack/lib/tilt/Tiltfile b/hack/lib/tilt/Tiltfile index e657e967..0c176e2c 100644 --- a/hack/lib/tilt/Tiltfile +++ b/hack/lib/tilt/Tiltfile @@ -103,6 +103,7 @@ k8s_yaml(local([ '--data-value-yaml', 'service_http_nodeport_nodeport=31234', '--data-value-yaml', 'service_https_nodeport_port=443', '--data-value-yaml', 'service_https_nodeport_nodeport=31243', + '--data-value-yaml', 'service_https_clusterip_port=443', '--data-value-yaml', 'custom_labels={mySupervisorCustomLabelName: mySupervisorCustomLabelValue}', ])) # Tell tilt to watch all of those files for changes. diff --git a/hack/prepare-for-integration-tests.sh b/hack/prepare-for-integration-tests.sh index 11e1fbf8..97bdcceb 100755 --- a/hack/prepare-for-integration-tests.sh +++ b/hack/prepare-for-integration-tests.sh @@ -230,6 +230,7 @@ if ! tilt_mode; then --data-value-yaml 'service_http_nodeport_nodeport=31234' \ --data-value-yaml 'service_https_nodeport_port=443' \ --data-value-yaml 'service_https_nodeport_nodeport=31243' \ + --data-value-yaml 'service_https_clusterip_port=443' \ >"$manifest" kapp deploy --yes --app "$supervisor_app_name" --diff-changes --file "$manifest" @@ -302,7 +303,7 @@ export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER=https://dex.dex.svc.cluster export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER_CA_BUNDLE="${test_ca_bundle_pem}" export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_ID=pinniped-supervisor export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_SECRET=pinniped-supervisor-secret -export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://127.0.0.1:12345/some/path/callback +export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME=pinny@example.com export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD=password diff --git a/internal/certauthority/certauthority.go b/internal/certauthority/certauthority.go index 13636db4..87bdd784 100644 --- a/internal/certauthority/certauthority.go +++ b/internal/certauthority/certauthority.go @@ -136,6 +136,13 @@ func (c *CA) Bundle() []byte { return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes}) } +// Pool returns the current CA signing bundle as a *x509.CertPool. +func (c *CA) Pool() *x509.CertPool { + pool := x509.NewCertPool() + pool.AppendCertsFromPEM(c.Bundle()) + return pool +} + // Issue a new server certificate for the given identity and duration. func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time.Duration) (*tls.Certificate, error) { // Choose a random 128 bit serial number. @@ -194,7 +201,7 @@ func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time. } // IssuePEM issues a new server certificate for the given identity and duration, returning it as a pair of -// PEM-formatted byte slices for the certificate and private key. +// PEM-formatted byte slices for the certificate and private key. func (c *CA) IssuePEM(subject pkix.Name, dnsNames []string, ttl time.Duration) ([]byte, []byte, error) { return toPEM(c.Issue(subject, dnsNames, nil, ttl)) } diff --git a/internal/certauthority/certauthority_test.go b/internal/certauthority/certauthority_test.go index 10e74743..4c1fdf8e 100644 --- a/internal/certauthority/certauthority_test.go +++ b/internal/certauthority/certauthority_test.go @@ -182,6 +182,16 @@ func TestBundle(t *testing.T) { }) } +func TestPool(t *testing.T) { + t.Run("success", func(t *testing.T) { + ca, err := New(pkix.Name{CommonName: "test"}, 1*time.Hour) + require.NoError(t, err) + + got := ca.Pool() + require.Len(t, got.Subjects(), 1) + }) +} + type errSigner struct { pubkey crypto.PublicKey err error diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go index bc3db3bf..7faa4d9c 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher.go @@ -17,6 +17,7 @@ import ( "github.com/coreos/go-oidc" "github.com/go-logr/logr" + "golang.org/x/oauth2" "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" @@ -30,6 +31,7 @@ import ( pinnipedcontroller "go.pinniped.dev/internal/controller" "go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/upstreamoidc" ) const ( @@ -62,21 +64,27 @@ const ( // IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. type IDPCache interface { - SetIDPList([]provider.UpstreamOIDCIdentityProvider) + SetIDPList([]provider.UpstreamOIDCIdentityProviderI) } // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. type lruValidatorCache struct{ cache *cache.Expiring } -func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider { - if result, ok := c.cache.Get(c.cacheKey(spec)); ok { - return result.(*oidc.Provider) - } - return nil +type lruValidatorCacheEntry struct { + provider *oidc.Provider + client *http.Client } -func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) { - c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL) +func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) { + if result, ok := c.cache.Get(c.cacheKey(spec)); ok { + entry := result.(*lruValidatorCacheEntry) + return entry.provider, entry.client + } + return nil, nil +} + +func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) { + c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL) } func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} { @@ -95,8 +103,8 @@ type controller struct { providers idpinformers.UpstreamOIDCProviderInformer secrets corev1informers.SecretInformer validatorCache interface { - getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider - putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) + getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) + putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client) } } @@ -132,13 +140,13 @@ func (c *controller) Sync(ctx controllerlib.Context) error { } requeue := false - validatedUpstreams := make([]provider.UpstreamOIDCIdentityProvider, 0, len(actualUpstreams)) + validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams)) for _, upstream := range actualUpstreams { valid := c.validateUpstream(ctx, upstream) if valid == nil { requeue = true } else { - validatedUpstreams = append(validatedUpstreams, *valid) + validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid)) } } c.cache.SetIDPList(validatedUpstreams) @@ -150,10 +158,14 @@ func (c *controller) Sync(ctx controllerlib.Context) error { // validateUpstream validates the provided v1alpha1.UpstreamOIDCProvider and returns the validated configuration as a // provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.UpstreamOIDCProvider. -func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *provider.UpstreamOIDCIdentityProvider { - result := provider.UpstreamOIDCIdentityProvider{ - Name: upstream.Name, - Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes), +func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *upstreamoidc.ProviderConfig { + result := upstreamoidc.ProviderConfig{ + Name: upstream.Name, + Config: &oauth2.Config{ + Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes), + }, + UsernameClaim: upstream.Spec.Claims.Username, + GroupsClaim: upstream.Spec.Claims.Groups, } conditions := []*v1alpha1.Condition{ c.validateSecret(upstream, &result), @@ -180,7 +192,7 @@ func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alp } // validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition. -func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { +func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { secretName := upstream.Spec.Client.SecretName // Fetch the Secret from informer cache. @@ -217,7 +229,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res } // If everything is valid, update the result and set the condition to true. - result.ClientID = string(clientID) + result.Config.ClientID = string(clientID) + result.Config.ClientSecret = string(clientSecret) return &v1alpha1.Condition{ Type: typeClientCredsValid, Status: v1alpha1.ConditionTrue, @@ -227,9 +240,9 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res } // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. -func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { - // Get the provider (from cache if possible). - discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) +func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition { + // Get the provider and HTTP Client from cache if possible. + discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec) // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. if discoveredProvider == nil { @@ -242,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst Message: err.Error(), } } - httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} + httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer) if err != nil { @@ -255,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst } // Update the cache with the newly discovered value. - c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) + c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient) } // Parse out and validate the discovered authorize endpoint. @@ -278,7 +291,9 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst } // If everything is valid, update the result and set the condition to true. - result.AuthorizationURL = *authURL + result.Config.Endpoint = discoveredProvider.Endpoint() + result.Provider = discoveredProvider + result.Client = httpClient return &v1alpha1.Condition{ Type: typeOIDCDiscoverySucceeded, Status: v1alpha1.ConditionTrue, diff --git a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go index 949effaf..3ecfa91a 100644 --- a/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go +++ b/internal/controller/supervisorconfig/upstreamwatcher/upstreamwatcher_test.go @@ -24,9 +24,11 @@ import ( pinnipedfake "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned/fake" pinnipedinformers "go.pinniped.dev/generated/1.19/client/supervisor/informers/externalversions" "go.pinniped.dev/internal/controllerlib" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/testlogger" + "go.pinniped.dev/internal/upstreamoidc" ) func TestController(t *testing.T) { @@ -49,6 +51,8 @@ func TestController(t *testing.T) { testClientID = "test-oidc-client-id" testClientSecret = "test-oidc-client-secret" testValidSecretData = map[string][]byte{"clientID": []byte(testClientID), "clientSecret": []byte(testClientSecret)} + testGroupsClaim = "test-groups-claim" + testUsernameClaim = "test-username-claim" ) tests := []struct { name string @@ -56,7 +60,7 @@ func TestController(t *testing.T) { inputSecrets []runtime.Object wantErr string wantLogs []string - wantResultingCache []provider.UpstreamOIDCIdentityProvider + wantResultingCache []provider.UpstreamOIDCIdentityProviderI wantResultingUpstreams []v1alpha1.UpstreamOIDCProvider }{ { @@ -80,7 +84,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="secret \"test-client-secret\" not found" "name"="test-name" "namespace"="test-namespace" "reason"="SecretNotFound" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -126,7 +130,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" has wrong type \"some-other-type\" (should be \"secrets.pinniped.dev/oidc-client\")" "name"="test-name" "namespace"="test-namespace" "reason"="SecretWrongType" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -171,7 +175,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" is missing required keys [\"clientID\" \"clientSecret\"]" "name"="test-name" "namespace"="test-namespace" "reason"="SecretMissingKeys" "type"="ClientCredentialsValid"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -219,7 +223,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -267,7 +271,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: no certificates found" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: no certificates found" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -312,7 +316,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to perform OIDC discovery against \"invalid-url\"" "reason"="Unreachable" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to perform OIDC discovery against \"invalid-url\"" "name"="test-name" "namespace"="test-namespace" "reason"="Unreachable" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -358,7 +362,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -404,7 +408,7 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{}, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -437,6 +441,7 @@ func TestController(t *testing.T) { TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, Client: v1alpha1.OIDCClient{SecretName: testSecretName}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: append(testAdditionalScopes, "xyz", "openid")}, + Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim}, }, Status: v1alpha1.UpstreamOIDCProviderStatus{ Phase: "Error", @@ -455,12 +460,16 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ - Name: testName, - ClientID: testClientID, - AuthorizationURL: *testIssuerAuthorizeURL, - Scopes: append(testExpectedScopes, "xyz"), - }}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{ + &oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: testName, + ClientID: testClientID, + AuthorizationURL: *testIssuerAuthorizeURL, + Scopes: append(testExpectedScopes, "xyz"), + UsernameClaim: testUsernameClaim, + GroupsClaim: testGroupsClaim, + }, + }, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -481,6 +490,7 @@ func TestController(t *testing.T) { TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, Client: v1alpha1.OIDCClient{SecretName: testSecretName}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: testAdditionalScopes}, + Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim}, }, Status: v1alpha1.UpstreamOIDCProviderStatus{ Phase: "Ready", @@ -499,12 +509,16 @@ func TestController(t *testing.T) { `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, }, - wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ - Name: testName, - ClientID: testClientID, - AuthorizationURL: *testIssuerAuthorizeURL, - Scopes: testExpectedScopes, - }}, + wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{ + &oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: testName, + ClientID: testClientID, + AuthorizationURL: *testIssuerAuthorizeURL, + Scopes: testExpectedScopes, + UsernameClaim: testUsernameClaim, + GroupsClaim: testGroupsClaim, + }, + }, wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, Status: v1alpha1.UpstreamOIDCProviderStatus{ @@ -527,7 +541,9 @@ func TestController(t *testing.T) { kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) testLog := testlogger.New(t) cache := provider.NewDynamicUpstreamIDPProvider() - cache.SetIDPList([]provider.UpstreamOIDCIdentityProvider{{Name: "initial-entry"}}) + cache.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{ + &upstreamoidc.ProviderConfig{Name: "initial-entry"}, + }) controller := New( cache, @@ -551,7 +567,18 @@ func TestController(t *testing.T) { require.NoError(t, err) } require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n")) - require.ElementsMatch(t, tt.wantResultingCache, cache.GetIDPList()) + + actualIDPList := cache.GetIDPList() + require.Equal(t, len(tt.wantResultingCache), len(actualIDPList)) + for i := range actualIDPList { + actualIDP := actualIDPList[i].(*upstreamoidc.ProviderConfig) + require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName()) + require.Equal(t, tt.wantResultingCache[i].GetClientID(), actualIDP.GetClientID()) + require.Equal(t, tt.wantResultingCache[i].GetAuthorizationURL().String(), actualIDP.GetAuthorizationURL().String()) + require.Equal(t, tt.wantResultingCache[i].GetUsernameClaim(), actualIDP.GetUsernameClaim()) + require.Equal(t, tt.wantResultingCache[i].GetGroupsClaim(), actualIDP.GetGroupsClaim()) + require.ElementsMatch(t, tt.wantResultingCache[i].GetScopes(), actualIDP.GetScopes()) + } actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{}) require.NoError(t, err) diff --git a/internal/crud/crud.go b/internal/crud/crud.go index 82d32082..9e15581d 100644 --- a/internal/crud/crud.go +++ b/internal/crud/crud.go @@ -30,7 +30,7 @@ const ( ErrSecretTypeMismatch = constable.Error("secret storage data has incorrect type") ErrSecretLabelMismatch = constable.Error("secret storage data has incorrect label") - ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") // TODO do we need this? + ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") ) type Storage interface { @@ -139,7 +139,7 @@ func (s *secretsStorage) toSecret(signature, resourceVersion string, data JSON) Labels: map[string]string{ secretLabelKey: s.resource, // make it easier to find this stuff via kubectl }, - OwnerReferences: nil, // TODO we should set this to make sure stuff gets clean up + OwnerReferences: nil, }, Data: map[string][]byte{ secretDataKey: buf, diff --git a/internal/crud/crud_test.go b/internal/crud/crud_test.go index 93ee9818..cb0fc147 100644 --- a/internal/crud/crud_test.go +++ b/internal/crud/crud_test.go @@ -62,17 +62,17 @@ func TestStorage(t *testing.T) { }{ { name: "get non-existent", - resource: "authorization-codes", + resource: "authcode", mocks: nil, run: func(t *testing.T, storage Storage) error { _, err := storage.Get(ctx, "not-exists", nil) return err }, wantActions: []coretesting.Action{ - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-t2fx46yyvs3a"), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-t2fx46yyvs3a"), }, wantSecrets: nil, - wantErr: `failed to get authorization-codes for signature not-exists: secrets "pinniped-storage-authorization-codes-t2fx46yyvs3a" not found`, + wantErr: `failed to get authcode for signature not-exists: secrets "pinniped-storage-authcode-t2fx46yyvs3a" not found`, }, { name: "delete non-existent", diff --git a/internal/fosite/authorizationcode/authorizationcode_test.go b/internal/fosite/authorizationcode/authorizationcode_test.go deleted file mode 100644 index c434ff85..00000000 --- a/internal/fosite/authorizationcode/authorizationcode_test.go +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -package authorizationcode - -import ( - "context" - "crypto/ed25519" - "crypto/x509" - "encoding/json" - "math/rand" - "net/url" - "strings" - "testing" - "time" - - fuzz "github.com/google/gofuzz" - "github.com/ory/fosite" - "github.com/ory/fosite/handler/oauth2" - "github.com/ory/fosite/handler/openid" - "github.com/stretchr/testify/require" - "gopkg.in/square/go-jose.v2" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/client-go/kubernetes/fake" - coretesting "k8s.io/client-go/testing" -) - -func TestAuthorizeCodeStorage(t *testing.T) { - ctx := context.Background() - secretsGVR := schema.GroupVersionResource{ - Group: "", - Version: "v1", - Resource: "secrets", - } - - const namespace = "test-ns" - - type mocker interface { - AddReactor(verb, resource string, reaction coretesting.ReactionFunc) - PrependReactor(verb, resource string, reaction coretesting.ReactionFunc) - Tracker() coretesting.ObjectTracker - } - - tests := []struct { - name string - mocks func(*testing.T, mocker) - run func(*testing.T, oauth2.AuthorizeCodeStorage) error - wantActions []coretesting.Action - wantSecrets []corev1.Secret - wantErr string - }{ - { - name: "create, get, invalidate standard flow", - mocks: nil, - run: func(t *testing.T, storage oauth2.AuthorizeCodeStorage) error { - request := &fosite.AuthorizeRequest{ - ResponseTypes: fosite.Arguments{"not-code"}, - RedirectURI: &url.URL{ - Scheme: "", - Opaque: "weee", - User: &url.Userinfo{}, - Host: "", - Path: "/callback", - RawPath: "", - ForceQuery: false, - RawQuery: "", - Fragment: "", - }, - State: "stated", - HandledResponseTypes: fosite.Arguments{"not-type"}, - Request: fosite.Request{ - ID: "abcd-1", - RequestedAt: time.Time{}, - Client: &fosite.DefaultOpenIDConnectClient{ - DefaultClient: &fosite.DefaultClient{ - ID: "pinny", - Secret: nil, - RedirectURIs: nil, - GrantTypes: nil, - ResponseTypes: nil, - Scopes: nil, - Audience: nil, - Public: true, - }, - JSONWebKeysURI: "where", - JSONWebKeys: nil, - TokenEndpointAuthMethod: "something", - RequestURIs: nil, - RequestObjectSigningAlgorithm: "", - TokenEndpointAuthSigningAlgorithm: "", - }, - RequestedScope: nil, - GrantedScope: nil, - Form: url.Values{"key": []string{"val"}}, - Session: &openid.DefaultSession{ - Claims: nil, - Headers: nil, - ExpiresAt: nil, - Username: "snorlax", - Subject: "panda", - }, - RequestedAudience: nil, - GrantedAudience: nil, - }, - } - err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) - require.NoError(t, err) - - newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) - require.NoError(t, err) - require.Equal(t, request, newRequest) - - return storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") - }, - wantActions: []coretesting.Action{ - coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":true,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authorization-codes", - }), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"), - coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"), - coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authorization-codes", - }), - }, - wantSecrets: []corev1.Secret{ - { - ObjectMeta: metav1.ObjectMeta{ - Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4", - Namespace: namespace, - ResourceVersion: "", - Labels: map[string]string{ - "storage.pinniped.dev": "authorization-codes", - }, - }, - Data: map[string][]byte{ - "pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), - "pinniped-storage-version": []byte("1"), - }, - Type: "storage.pinniped.dev/authorization-codes", - }, - }, - wantErr: "", - }, - } - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - client := fake.NewSimpleClientset() - if tt.mocks != nil { - tt.mocks(t, client) - } - secrets := client.CoreV1().Secrets(namespace) - storage := New(secrets) - - err := tt.run(t, storage) - - require.Equal(t, tt.wantErr, errString(err)) - require.Equal(t, tt.wantActions, client.Actions()) - - actualSecrets, err := secrets.List(ctx, metav1.ListOptions{}) - require.NoError(t, err) - require.Equal(t, tt.wantSecrets, actualSecrets.Items) - }) - } -} - -func errString(err error) string { - if err == nil { - return "" - } - - return err.Error() -} - -// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session. -// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations. -func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) { - validSession := NewValidEmptyAuthorizeCodeSession() - - // sanity check our valid session - extractedRequest, err := validateAndExtractAuthorizeRequest(validSession.Request) - require.NoError(t, err) - require.Equal(t, validSession.Request, extractedRequest) - - // checked above - defaultClient := validSession.Request.Request.Client.(*fosite.DefaultOpenIDConnectClient) - defaultSession := validSession.Request.Request.Session.(*openid.DefaultSession) - - // makes it easier to use a raw string - replacer := strings.NewReplacer("`", "a") - randString := func(c fuzz.Continue) string { - for { - s := c.RandString() - if len(s) == 0 { - continue // skip empty string - } - return replacer.Replace(s) - } - } - - // deterministic fuzzing of fosite.AuthorizeRequest - f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs( - // these functions guarantee that these are the only interface types we need to fill out - // if fosite.AuthorizeRequest changes to add more, the fuzzer will panic - func(fc *fosite.Client, c fuzz.Continue) { - c.Fuzz(defaultClient) - *fc = defaultClient - }, - func(fs *fosite.Session, c fuzz.Continue) { - c.Fuzz(defaultSession) - *fs = defaultSession - }, - - // these types contain an interface{} that we need to handle - // this is safe because we explicitly provide the openid.DefaultSession concrete type - func(value *map[string]interface{}, c fuzz.Continue) { - // cover all the JSON data types just in case - *value = map[string]interface{}{ - randString(c): float64(c.Intn(1 << 32)), - randString(c): map[string]interface{}{ - randString(c): []interface{}{float64(c.Intn(1 << 32))}, - randString(c): map[string]interface{}{ - randString(c): nil, - randString(c): map[string]interface{}{ - randString(c): c.RandBool(), - }, - }, - }, - } - }, - // JWK contains an interface{} Key that we need to handle - // this is safe because JWK explicitly implements JSON marshalling and unmarshalling - func(jwk *jose.JSONWebKey, c fuzz.Continue) { - key, _, err := ed25519.GenerateKey(c) - require.NoError(t, err) - jwk.Key = key - - // set these fields to make the .Equal comparison work - jwk.Certificates = []*x509.Certificate{} - jwk.CertificatesURL = &url.URL{} - jwk.CertificateThumbprintSHA1 = []byte{} - jwk.CertificateThumbprintSHA256 = []byte{} - }, - - // set this to make the .Equal comparison work - // this is safe because Time explicitly implements JSON marshalling and unmarshalling - func(tp *time.Time, c fuzz.Continue) { - *tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC() - }, - - // make random strings that do not contain any ` characters - func(s *string, c fuzz.Continue) { - *s = randString(c) - }, - // handle string type alias - func(s *fosite.TokenType, c fuzz.Continue) { - *s = fosite.TokenType(randString(c)) - }, - // handle string type alias - func(s *fosite.Arguments, c fuzz.Continue) { - n := c.Intn(3) + 1 // 1 to 3 items - arguments := make(fosite.Arguments, n) - for i := range arguments { - arguments[i] = randString(c) - } - *s = arguments - }, - ) - - f.Fuzz(validSession) - - const name = "fuzz" // value is irrelevant - ctx := context.Background() - secrets := fake.NewSimpleClientset().CoreV1().Secrets(name) - storage := New(secrets) - - // issue a create using the fuzzed request to confirm that marshalling works - err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request) - require.NoError(t, err) - - // retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works - newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil) - require.NoError(t, err) - - // the fuzzed request and the copy from storage should be exactly the same - require.Equal(t, validSession.Request, newRequest) - - secretList, err := secrets.List(ctx, metav1.ListOptions{}) - require.NoError(t, err) - require.Len(t, secretList.Items, 1) - authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"]) - - // set these to match CreateAuthorizeCodeSession so that .JSONEq works - validSession.Active = true - validSession.Version = "1" - - validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t") - require.NoError(t, err) - authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes) - - // the fuzzed session and storage session should have identical JSON - require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage) - - // while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed, - // if it adds a new field that can be fuzzed, this check will fail - // thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field) - require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing) -} diff --git a/internal/fosite/authorizationcode/authorizationcode.go b/internal/fositestorage/authorizationcode/authorizationcode.go similarity index 52% rename from internal/fosite/authorizationcode/authorizationcode.go rename to internal/fositestorage/authorizationcode/authorizationcode.go index 917c5cc0..8aca618a 100644 --- a/internal/fosite/authorizationcode/authorizationcode.go +++ b/internal/fositestorage/authorizationcode/authorizationcode.go @@ -16,11 +16,11 @@ import ( "go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" ) const ( - ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.AuthorizeRequest") - ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil") + ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must be present") ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version") authorizeCodeStorageVersion = "1" @@ -33,26 +33,25 @@ type authorizeCodeStorage struct { } type AuthorizeCodeSession struct { - Active bool `json:"active"` - Request *fosite.AuthorizeRequest `json:"request"` - Version string `json:"version"` + Active bool `json:"active"` + Request *fosite.Request `json:"request"` + Version string `json:"version"` } func New(secrets corev1client.SecretInterface) oauth2.AuthorizeCodeStorage { - return &authorizeCodeStorage{storage: crud.New("authorization-codes", secrets)} + return &authorizeCodeStorage{storage: crud.New("authcode", secrets)} } func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { - // this conversion assumes that we do not wrap the default type in any way + // This conversion assumes that we do not wrap the default type in any way // i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation // note that because this type is serialized and stored in Kube, we cannot easily change the implementation later - // TODO hydra uses the fosite.Request struct and ignores the extra fields in fosite.AuthorizeRequest - request, err := validateAndExtractAuthorizeRequest(requester) + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) if err != nil { return err } - // TODO hydra stores specific fields from the requester + // Note, in case it is helpful, that Hydra stores specific fields from the requester: // request ID // requestedAt // OAuth client ID @@ -70,12 +69,11 @@ func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, s } func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { - // TODO hydra uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes - - // TODO hydra gets the client from its DB as a concrete type via client ID, - // the hydra memory client just validates that the client ID exists - - // TODO hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length + // Note, in case it is helpful, that Hydra: + // - uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes + // - gets the client from its DB as a concrete type via client ID, the hydra memory client just validates that the + // client ID exists + // - hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length session, _, err := a.getSession(ctx, signature) @@ -88,8 +86,6 @@ func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, sign } func (a *authorizeCodeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error { - // TODO write garbage collector for these codes - session, rv, err := a.getSession(ctx, signature) if err != nil { return err @@ -123,7 +119,7 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string) ErrInvalidAuthorizeRequestVersion, signature, version, authorizeCodeStorageVersion) } - if session.Request == nil { + if session.Request.ID == "" { return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData) } @@ -137,31 +133,13 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string) func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession { return &AuthorizeCodeSession{ - Request: &fosite.AuthorizeRequest{ - Request: fosite.Request{ - Client: &fosite.DefaultOpenIDConnectClient{}, - Session: &openid.DefaultSession{}, - }, + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, }, } } -func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.AuthorizeRequest, error) { - request, ok1 := requester.(*fosite.AuthorizeRequest) - if !ok1 { - return nil, ErrInvalidAuthorizeRequestType - } - _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) - _, ok3 := request.Session.(*openid.DefaultSession) - - valid := ok2 && ok3 - if !valid { - return nil, ErrInvalidAuthorizeRequestType - } - - return request, nil -} - var _ interface { Is(error) bool Unwrap() error @@ -189,59 +167,37 @@ func (e *errSerializationFailureWithCause) Error() string { const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ "active": true, "request": { - "responseTypes": [ - "¥Îʒ襧.ɕ7崛瀇莒AȒ[ɠ牐7#$ɭ", - ".5ȿELj9ûF済(D疻翋膗", - "螤Yɫüeɯ紤邥翔勋\\RBʒ;-" - ], - "redirectUri": { - "Scheme": "ħesƻU赒M喦_ģ", - "Opaque": "Ġ/_章Ņ缘T蝟NJ儱礹燃ɢ", - "User": {}, - "Host": "ȳ4螘Wo", - "Path": "}i{", - "RawPath": "5Dža丝eF0eė鱊hǒx蔼Q", - "ForceQuery": true, - "RawQuery": "熤1bbWV", - "Fragment": "ȋc剠鏯ɽÿ¸", - "RawFragment": "qƤ" - }, - "state": "@n,x竘Şǥ嗾稀'ã击漰怼禝穞梠Ǫs", - "handledResponseTypes": [ - "m\"e尚鬞ƻɼ抹d誉y鿜Ķ" - ], - "id": "ō澩ć|3U2Ǜl霨ǦǵpƉ", - "requestedAt": "1989-11-05T22:02:31.105295894Z", + "id": "嫎l蟲aƖ啘艿", + "requestedAt": "2082-11-10T18:36:11.627253638Z", "client": { - "id": "[:c顎疻紵D", - "client_secret": "mQ==", + "id": "!ſɄĈp[述齛ʘUȻ.5ȿE", + "client_secret": "UQ==", "redirect_uris": [ - "恣S@T嵇LJV,Æ櫔袆鋹奘菲", - "ãƻʚ肈ą8O+a駣Ʉɼk瘸'鴵y" + "ǣ珑 ʑ飶畛Ȳ螤Yɫüeɯ紤邥翔勋\\", + "Bʒ;", + "鿃攴Ųęʍ鎾ʦ©cÏN,Ġ/_" ], "grant_types": [ - ".湆ê\"唐", - "曎餄FxD溪躲珫ÈşɜȨû臓嬣\"ǃŤz" + "憉sHĒ尥窘挼Ŀʼn" ], "response_types": [ - "Ņʘʟ車sʊ儓JǐŪɺǣy|耑ʄ" + "4", + "ʄÔ@}i{絧遗Ū^ȝĸ谋Vʋ鱴閇T" ], "scopes": [ - "Ą", - "萙Į(潶饏熞ĝƌĆ1", - "əȤ4Į筦p煖鵄$睱奐耡q" + "R鴝順諲ŮŚ节ȭŀȋc剠鏯ɽÿ¸" ], "audience": [ - "Ʃǣ鿫/Ò敫ƤV" + "Ƥ" ], "public": true, - "jwks_uri": "ȩđ[嬧鱒Ȁ彆媚杨嶒ĤG", + "jwks_uri": "BA瘪囷ɫCʄɢ雐譄uée'", "jwks": { "keys": [ { "kty": "OKP", "crv": "Ed25519", - "x": "JmA-6KpjzqKu0lq9OiB6ORL4s2UzBFPsE1hm6vESeXM", + "x": "nK9xgX_iN7u3u_i8YOO7ZRT_WK028Vd_nhtsUu7Eo6E", "x5u": { "Scheme": "", "Opaque": "", @@ -258,24 +214,7 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ { "kty": "OKP", "crv": "Ed25519", - "x": "LbRC1_3HEe5o7Japk9jFp3_7Ou7Gi2gpqrVrIi0eLDQ", - "x5u": { - "Scheme": "", - "Opaque": "", - "User": null, - "Host": "", - "Path": "", - "RawPath": "", - "ForceQuery": false, - "RawQuery": "", - "Fragment": "", - "RawFragment": "" - } - }, - { - "kty": "OKP", - "crv": "Ed25519", - "x": "Ovk4DF8Yn3mkULuTqnlGJxFnKGu9EL6Xcf2Nql9lK3c", + "x": "UbbswQgzWhfGCRlwQmMp6fw_HoIoqkIaKT-2XN2fuYU", "x5u": { "Scheme": "", "Opaque": "", @@ -291,91 +230,95 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ } ] }, - "token_endpoint_auth_method": "\u0026(K鵢Kj ŏ9Q韉Ķ%嶑輫ǘ(", + "token_endpoint_auth_method": "ŚǗƳȕ暭Q0ņP羾,塐", "request_uris": [ - ":", - "6ě#嫀^xz Ū胧r" + "lj翻LH^俤µDzɹ@©|\u003eɃ", + "[:c顎疻紵D" ], - "request_object_signing_alg": "^¡!犃ĹĐJí¿ō擫ų懫砰¿", - "token_endpoint_auth_signing_alg": "ƈŮå" + "request_object_signing_alg": "m1Ì恣S@T嵇LJV,Æ櫔袆鋹奘", + "token_endpoint_auth_signing_alg": "Fãƻʚ肈ą8O+a駣" }, "scopes": [ - "阃.Ù頀ʌGa皶竇瞍涘¹", - "ȽŮ切衖庀ŰŒ矠", - "楓)馻řĝǕ菸Tĕ1伞柲\u003c\"ʗȆ\\雤" + "ɼk瘸'鴵yſǮŁ±\u003eFA曎餄FxD溪", + "綻N镪p赌h%桙dĽ" ], "grantedScopes": [ - "ơ鮫R嫁ɍUƞ9+u!Ȱ", - "}Ă岜" + "癗E]Ņʘʟ車s" ], "form": { - "旸Ť/Õ薝隧;綡,鼞纂=": [ - "[滮]憀", - "3\u003eÙœ蓄UK嗤眇疟Țƒ1v¸KĶ" + "蹬器ķ8ŷ萒寎廭#疶昄Ą-Ƃƞ轵": [ + "熞ĝƌĆ1ȇyǴ濎=Tʉȼʁŀ\u003c", + "耡q戨稞R÷mȵg釽[ƞ@", + "đ[嬧鱒Ȁ彆媚杨嶒ĤGÀ吧Lŷ" + ], + "餟": [ + "蒍z\u0026(K鵢Kj ŏ9Q韉Ķ%", + "輫ǘ(¨Ƞ亱6ě#嫀^xz ", + "@耢ɝ^¡!犃ĹĐJí¿ō擫" ] }, "session": { "Claims": { - "JTI": "};Ų斻遟a衪荖舃", - "Issuer": "芠顋敀拲h蝺$!", - "Subject": "}j%(=ſ氆]垲莲顇", + "JTI": "懫砰¿C筽娴ƓaPu镈賆ŗɰ", + "Issuer": "皶竇瞍涘¹焕iǢǽɽĺŧ", + "Subject": "矠M6ɡǜg炾ʙ$%o6肿Ȫ", "Audience": [ - "彑V\\廳蟕Țǡ蔯ʠ浵Ī龉磈螖畭5", - "渇Ȯʕc" + "ƌÙ鯆GQơ鮫R嫁ɍUƞ9+u!Ȱ踾$" ], - "Nonce": "Ǖ=rlƆ褡{ǏS", - "ExpiresAt": "1975-11-17T14:21:34.205609651Z", - "IssuedAt": "2104-07-03T15:40:03.66710966Z", - "RequestedAt": "2031-05-18T05:14:19.449350555Z", - "AuthTime": "2018-01-27T07:55:06.056862114Z", - "AccessTokenHash": "鹰肁躧", - "AuthenticationContextClassReference": "}Ɇ", - "AuthenticationMethodsReference": "DQh:uȣ", - "CodeHash": "ɘȏıȒ諃龟", + "Nonce": "us旸Ť/Õ薝隧;綡,鼞", + "ExpiresAt": "2065-11-30T13:47:03.613000626Z", + "IssuedAt": "1976-02-22T09:57:20.479850437Z", + "RequestedAt": "2016-04-13T04:18:53.648949323Z", + "AuthTime": "2098-07-12T04:38:54.034043015Z", + "AccessTokenHash": "滮]", + "AuthenticationContextClassReference": "°3\u003eÙ", + "AuthenticationMethodsReference": "k?µ鱔ǤÂ", + "CodeHash": "Țƒ1v¸KĶ跭};", "Extra": { - "a": { - "^i臏f恡ƨ彮": { - "DĘ敨ýÏʥZq7烱藌\\": null, - "V": { - "őŧQĝ微'X焌襱ǭɕņ殥!_n": false - } - }, - "Ż猁": [ - 1706822246 - ] + "=ſ氆": { + "Ƿī,廖ʡ彑V\\廳蟕Ț": [ + 843216989 + ], + "蔯ʠ浵Ī": { + "H\"nǕ=rlƆ褡{ǏSȳŅ": { + "Žg": false + }, + "枱鰧ɛ鸁A渇": null + } }, - "Ò椪)ɫqň2搞Ŀ高摠鲒鿮禗O": 1233332227 + "斻遟a衪荖舃9闄岈锘肺ńʥƕU}j%": 2520197933 } }, "Headers": { "Extra": { - "?戋璖$9\u0026": { - "µcɕ餦ÑEǰ哤癨浦浏1R": [ - 3761201123 - ], - "頓ć§蚲6rǦ\u003cqċ": { - "Łʀ§ȏœɽDz斡冭ȸěaʜD捛?½ʀ+": null, - "ɒúIJ誠ƉyÖ.峷1藍殙菥趏": { - "jHȬȆ#)\u003cX": true + "熒ɘȏıȒ諃龟ŴŠ'耐Ƭ扵ƹ玄ɕwL": { + "ýÏʥZq7烱藌\\捀¿őŧQ": { + "微'X焌襱ǭɕņ殥!_": null, + "荇届UȚ?戋璖$9\u00269舋": { + "ɕ餦ÑEǰ哤癨浦浏1Rk頓ć§蚲6": true } - } + }, + "鲒鿮禗O暒aJP鐜?ĮV嫎h譭ȉ]DĘ": [ + 954647573 + ] }, - "U": 1354158262 + "皩Ƭ}Ɇ.雬Ɨ´唁": 1572524915 } }, "ExpiresAt": { - "\"嘬ȹĹaó剺撱Ȱ": "1985-09-09T04:35:40.533197189Z", - "ʆ\u003e": "1998-08-07T05:37:11.759718906Z", - "柏ʒ鴙*鸆偡Ȓ肯Ûx": "2036-12-19T06:36:14.414805124Z" + "\u003cqċ譈8ŪɎP绿MÅ": "2031-10-18T22:07:34.950803105Z", + "ȸěaʜD捛?½ʀ+Ċ偢镳ʬÍɷȓ\u003c": "2049-05-13T15:27:20.968432454Z" }, - "Username": "qmʎaðƠ绗ʢ緦Hū", - "Subject": "屾Ê窢ɋ鄊qɠ谫ǯǵƕ牀1鞊\\ȹ)" + "Username": "1藍殙菥趏酱Nʎ\u0026^横懋ƶ峦Fïȫƅw", + "Subject": "檾ĩĆ爨4犹|v炩f柏ʒ鴙*鸆偡" }, "requestedAudience": [ - "鉍商OɄƣ圔,xĪɏV鵅砍" + "肯Ûx穞Ƀ", + "ź蕴3ǐ薝Ƅ腲=ʐ诂鱰屾Ê窢ɋ鄊qɠ谫" ], "grantedAudience": [ - "C笜嚯\u003cǐšɚĀĥʋ6鉅\\þc涎漄Ɨ腼" + "ǵƕ牀1鞊\\ȹ)}鉍商OɄƣ圔,xĪ", + "悾xn冏裻摼0Ʈ蚵Ȼ塕»£#稏扟X" ] }, "version": "1" diff --git a/internal/fositestorage/authorizationcode/authorizationcode_test.go b/internal/fositestorage/authorizationcode/authorizationcode_test.go new file mode 100644 index 00000000..616eb2de --- /dev/null +++ b/internal/fositestorage/authorizationcode/authorizationcode_test.go @@ -0,0 +1,401 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package authorizationcode + +import ( + "context" + "crypto/ed25519" + "crypto/x509" + "encoding/json" + "fmt" + "math/rand" + "net/url" + "strings" + "testing" + "time" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" + + fuzz "github.com/google/gofuzz" + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + kubetesting "k8s.io/client-go/testing" + + "go.pinniped.dev/internal/fositestorage" +) + +const namespace = "test-ns" + +func TestAuthorizationCodeStorage(t *testing.T) { + ctx := context.Background() + secretsGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "secrets", + } + + wantActions := []kubetesting.Action{ + kubetesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + }), + kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"), + kubetesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + }), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", + }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, + } + err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) + require.NoError(t, err) + + newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) + + err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") + require.NoError(t, err) + + require.Equal(t, wantActions, client.Actions()) + + // Doing a Get on an invalidated session should still return the session, but also return an error. + invalidatedRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "authorization code session for fancy-signature has already been used: Authorization code has ben invalidated") + require.Equal(t, "abcd-1", invalidatedRequest.GetID()) +} + +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetAuthorizeCodeSession(ctx, "non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestInvalidateWhenNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + notFoundErr := storage.InvalidateAuthorizeCodeSession(ctx, "non-existent-signature") + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestInvalidateWhenConflictOnUpdateHappens(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + client.PrependReactor("update", "secrets", func(_ kubetesting.Action) (bool, runtime.Object, error) { + return true, nil, apierrors.NewConflict(schema.GroupResource{ + Group: "", + Resource: "secrets", + }, "some-secret-name", fmt.Errorf("there was a conflict")) + }) + + request := &fosite.Request{ + ID: "some-request-id", + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + } + err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request) + require.NoError(t, err) + err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature") + require.EqualError(t, err, `The request could not be completed due to concurrent access: failed to update authcode for signature fancy-signature at resource version : Operation cannot be fulfilled on secrets "some-secret-name": there was a conflict`) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version", "active": true}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + + require.EqualError(t, err, "authorization request data has wrong version: authorization code session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "authcode", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value", "version":"1", "active": true}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/authcode", + } + + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "malformed authorization code session for fancy-signature: authorization request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") +} + +// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session. +// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations. +func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) { + validSession := NewValidEmptyAuthorizeCodeSession() + + // sanity check our valid session + extractedRequest, err := fositestorage.ValidateAndExtractAuthorizeRequest(validSession.Request) + require.NoError(t, err) + require.Equal(t, validSession.Request, extractedRequest) + + // checked above + defaultClient := validSession.Request.Client.(*fosite.DefaultOpenIDConnectClient) + defaultSession := validSession.Request.Session.(*openid.DefaultSession) + + // makes it easier to use a raw string + replacer := strings.NewReplacer("`", "a") + randString := func(c fuzz.Continue) string { + for { + s := c.RandString() + if len(s) == 0 { + continue // skip empty string + } + return replacer.Replace(s) + } + } + + // deterministic fuzzing of fosite.Request + f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs( + // these functions guarantee that these are the only interface types we need to fill out + // if fosite.Request changes to add more, the fuzzer will panic + func(fc *fosite.Client, c fuzz.Continue) { + c.Fuzz(defaultClient) + *fc = defaultClient + }, + func(fs *fosite.Session, c fuzz.Continue) { + c.Fuzz(defaultSession) + *fs = defaultSession + }, + + // these types contain an interface{} that we need to handle + // this is safe because we explicitly provide the openid.DefaultSession concrete type + func(value *map[string]interface{}, c fuzz.Continue) { + // cover all the JSON data types just in case + *value = map[string]interface{}{ + randString(c): float64(c.Intn(1 << 32)), + randString(c): map[string]interface{}{ + randString(c): []interface{}{float64(c.Intn(1 << 32))}, + randString(c): map[string]interface{}{ + randString(c): nil, + randString(c): map[string]interface{}{ + randString(c): c.RandBool(), + }, + }, + }, + } + }, + // JWK contains an interface{} Key that we need to handle + // this is safe because JWK explicitly implements JSON marshalling and unmarshalling + func(jwk *jose.JSONWebKey, c fuzz.Continue) { + key, _, err := ed25519.GenerateKey(c) + require.NoError(t, err) + jwk.Key = key + + // set these fields to make the .Equal comparison work + jwk.Certificates = []*x509.Certificate{} + jwk.CertificatesURL = &url.URL{} + jwk.CertificateThumbprintSHA1 = []byte{} + jwk.CertificateThumbprintSHA256 = []byte{} + }, + + // set this to make the .Equal comparison work + // this is safe because Time explicitly implements JSON marshalling and unmarshalling + func(tp *time.Time, c fuzz.Continue) { + *tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC() + }, + + // make random strings that do not contain any ` characters + func(s *string, c fuzz.Continue) { + *s = randString(c) + }, + // handle string type alias + func(s *fosite.TokenType, c fuzz.Continue) { + *s = fosite.TokenType(randString(c)) + }, + // handle string type alias + func(s *fosite.Arguments, c fuzz.Continue) { + n := c.Intn(3) + 1 // 1 to 3 items + arguments := make(fosite.Arguments, n) + for i := range arguments { + arguments[i] = randString(c) + } + *s = arguments + }, + ) + + f.Fuzz(validSession) + + const name = "fuzz" // value is irrelevant + ctx := context.Background() + secrets := fake.NewSimpleClientset().CoreV1().Secrets(name) + storage := New(secrets) + + // issue a create using the fuzzed request to confirm that marshalling works + err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request) + require.NoError(t, err) + + // retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works + newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil) + require.NoError(t, err) + + // the fuzzed request and the copy from storage should be exactly the same + require.Equal(t, validSession.Request, newRequest) + + secretList, err := secrets.List(ctx, metav1.ListOptions{}) + require.NoError(t, err) + require.Len(t, secretList.Items, 1) + authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"]) + + // set these to match CreateAuthorizeCodeSession so that .JSONEq works + validSession.Active = true + validSession.Version = "1" + + validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t") + require.NoError(t, err) + authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes) + + // the fuzzed session and storage session should have identical JSON + require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage) + + // while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed, + // if it adds a new field that can be fuzzed, this check will fail + // thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field) + require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing) +} diff --git a/internal/fositestorage/fositestorage.go b/internal/fositestorage/fositestorage.go new file mode 100644 index 00000000..d23c9f6a --- /dev/null +++ b/internal/fositestorage/fositestorage.go @@ -0,0 +1,34 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package fositestorage + +import ( + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + + "go.pinniped.dev/internal/constable" +) + +const ( + ErrInvalidRequestType = constable.Error("requester must be of type fosite.Request") + ErrInvalidClientType = constable.Error("requester's client must be of type fosite.DefaultOpenIDConnectClient") + ErrInvalidSessionType = constable.Error("requester's session must be of type openid.DefaultSession") +) + +func ValidateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) { + request, ok1 := requester.(*fosite.Request) + if !ok1 { + return nil, ErrInvalidRequestType + } + _, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient) + if !ok2 { + return nil, ErrInvalidClientType + } + _, ok3 := request.Session.(*openid.DefaultSession) + if !ok3 { + return nil, ErrInvalidSessionType + } + + return request, nil +} diff --git a/internal/fositestorage/openidconnect/openidconnect.go b/internal/fositestorage/openidconnect/openidconnect.go new file mode 100644 index 00000000..797d21a8 --- /dev/null +++ b/internal/fositestorage/openidconnect/openidconnect.go @@ -0,0 +1,124 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package openidconnect + +import ( + "context" + "fmt" + "strings" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "k8s.io/apimachinery/pkg/api/errors" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" +) + +const ( + ErrInvalidOIDCRequestVersion = constable.Error("oidc request data has wrong version") + ErrInvalidOIDCRequestData = constable.Error("oidc request data must be present") + ErrMalformedAuthorizationCode = constable.Error("malformed authorization code") + + oidcStorageVersion = "1" +) + +var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{} + +type openIDConnectRequestStorage struct { + storage crud.Storage +} + +type session struct { + Request *fosite.Request `json:"request"` + Version string `json:"version"` +} + +func New(secrets corev1client.SecretInterface) openid.OpenIDConnectRequestStorage { + return &openIDConnectRequestStorage{storage: crud.New("oidc", secrets)} +} + +func (a *openIDConnectRequestStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error { + signature, err := getSignature(authcode) + if err != nil { + return err + } + + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) + if err != nil { + return err + } + + _, err = a.storage.Create(ctx, signature, &session{Request: request, Version: oidcStorageVersion}) + return err +} + +func (a *openIDConnectRequestStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, _ fosite.Requester) (fosite.Requester, error) { + signature, err := getSignature(authcode) + if err != nil { + return nil, err + } + + session, _, err := a.getSession(ctx, signature) + + if err != nil { + return nil, err + } + + return session.Request, err +} + +func (a *openIDConnectRequestStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error { + signature, err := getSignature(authcode) + if err != nil { + return err + } + + return a.storage.Delete(ctx, signature) +} + +func (a *openIDConnectRequestStorage) getSession(ctx context.Context, signature string) (*session, string, error) { + session := newValidEmptyOIDCSession() + rv, err := a.storage.Get(ctx, signature, session) + + if errors.IsNotFound(err) { + return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) + } + + if err != nil { + return nil, "", fmt.Errorf("failed to get oidc session for %s: %w", signature, err) + } + + if version := session.Version; version != oidcStorageVersion { + return nil, "", fmt.Errorf("%w: oidc session for %s has version %s instead of %s", + ErrInvalidOIDCRequestVersion, signature, version, oidcStorageVersion) + } + + if session.Request.ID == "" { + return nil, "", fmt.Errorf("malformed oidc session for %s: %w", signature, ErrInvalidOIDCRequestData) + } + + return session, rv, nil +} + +func newValidEmptyOIDCSession() *session { + return &session{ + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + }, + } +} + +func getSignature(authorizationCode string) (string, error) { + split := strings.Split(authorizationCode, ".") + + if len(split) != 2 { + return "", ErrMalformedAuthorizationCode + } + + return split[1], nil +} diff --git a/internal/fositestorage/openidconnect/openidconnect_test.go b/internal/fositestorage/openidconnect/openidconnect_test.go new file mode 100644 index 00000000..976828ed --- /dev/null +++ b/internal/fositestorage/openidconnect/openidconnect_test.go @@ -0,0 +1,209 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package openidconnect + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + coretesting "k8s.io/client-go/testing" +) + +const namespace = "test-ns" + +func TestOpenIdConnectStorage(t *testing.T) { + ctx := context.Background() + secretsGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "secrets", + } + + wantActions := []coretesting.Action{ + coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + }), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"), + coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", + }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, + } + err := storage.CreateOpenIDConnectSession(ctx, "fancy-code.fancy-signature", request) + require.NoError(t, err) + + newRequest, err := storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) + + err = storage.DeleteOpenIDConnectSession(ctx, "fancy-code.fancy-signature") + require.NoError(t, err) + + require.Equal(t, wantActions, client.Actions()) +} + +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetOpenIDConnectSession(ctx, "authcode.non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + + require.EqualError(t, err, "oidc request data has wrong version: oidc session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "oidc", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/oidc", + } + + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil) + require.EqualError(t, err, "malformed oidc session for fancy-signature: oidc request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") +} + +func TestAuthcodeHasNoDot(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreateOpenIDConnectSession(ctx, "all-one-part", nil) + require.EqualError(t, err, "malformed authorization code") +} diff --git a/internal/fositestorage/pkce/pkce.go b/internal/fositestorage/pkce/pkce.go new file mode 100644 index 00000000..9e8ef3d5 --- /dev/null +++ b/internal/fositestorage/pkce/pkce.go @@ -0,0 +1,98 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "context" + "fmt" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/handler/pkce" + "k8s.io/apimachinery/pkg/api/errors" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/crud" + "go.pinniped.dev/internal/fositestorage" +) + +const ( + ErrInvalidPKCERequestVersion = constable.Error("pkce request data has wrong version") + ErrInvalidPKCERequestData = constable.Error("pkce request data must be present") + + pkceStorageVersion = "1" +) + +var _ pkce.PKCERequestStorage = &pkceStorage{} + +type pkceStorage struct { + storage crud.Storage +} + +type session struct { + Request *fosite.Request `json:"request"` + Version string `json:"version"` +} + +func New(secrets corev1client.SecretInterface) pkce.PKCERequestStorage { + return &pkceStorage{storage: crud.New("pkce", secrets)} +} + +func (a *pkceStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { + request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester) + if err != nil { + return err + } + + _, err = a.storage.Create(ctx, signature, &session{Request: request, Version: pkceStorageVersion}) + return err +} + +func (a *pkceStorage) GetPKCERequestSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { + session, _, err := a.getSession(ctx, signature) + + if err != nil { + return nil, err + } + + return session.Request, err +} + +func (a *pkceStorage) DeletePKCERequestSession(ctx context.Context, signature string) error { + return a.storage.Delete(ctx, signature) +} + +func (a *pkceStorage) getSession(ctx context.Context, signature string) (*session, string, error) { + session := newValidEmptyPKCESession() + rv, err := a.storage.Get(ctx, signature, session) + + if errors.IsNotFound(err) { + return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error()) + } + + if err != nil { + return nil, "", fmt.Errorf("failed to get pkce session for %s: %w", signature, err) + } + + if version := session.Version; version != pkceStorageVersion { + return nil, "", fmt.Errorf("%w: pkce session for %s has version %s instead of %s", + ErrInvalidPKCERequestVersion, signature, version, pkceStorageVersion) + } + + if session.Request.ID == "" { + return nil, "", fmt.Errorf("malformed pkce session for %s: %w", signature, ErrInvalidPKCERequestData) + } + + return session, rv, nil +} + +func newValidEmptyPKCESession() *session { + return &session{ + Request: &fosite.Request{ + Client: &fosite.DefaultOpenIDConnectClient{}, + Session: &openid.DefaultSession{}, + }, + } +} diff --git a/internal/fositestorage/pkce/pkce_test.go b/internal/fositestorage/pkce/pkce_test.go new file mode 100644 index 00000000..80b2d9dd --- /dev/null +++ b/internal/fositestorage/pkce/pkce_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + coretesting "k8s.io/client-go/testing" +) + +const namespace = "test-ns" + +func TestPKCEStorage(t *testing.T) { + ctx := context.Background() + secretsGVR := schema.GroupVersionResource{ + Group: "", + Version: "v1", + Resource: "secrets", + } + + wantActions := []coretesting.Action{ + coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + }), + coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"), + coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"), + } + + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + ID: "abcd-1", + RequestedAt: time.Time{}, + Client: &fosite.DefaultOpenIDConnectClient{ + DefaultClient: &fosite.DefaultClient{ + ID: "pinny", + Secret: nil, + RedirectURIs: nil, + GrantTypes: nil, + ResponseTypes: nil, + Scopes: nil, + Audience: nil, + Public: true, + }, + JSONWebKeysURI: "where", + JSONWebKeys: nil, + TokenEndpointAuthMethod: "something", + RequestURIs: nil, + RequestObjectSigningAlgorithm: "", + TokenEndpointAuthSigningAlgorithm: "", + }, + RequestedScope: nil, + GrantedScope: nil, + Form: url.Values{"key": []string{"val"}}, + Session: &openid.DefaultSession{ + Claims: nil, + Headers: nil, + ExpiresAt: nil, + Username: "snorlax", + Subject: "panda", + }, + RequestedAudience: nil, + GrantedAudience: nil, + } + err := storage.CreatePKCERequestSession(ctx, "fancy-signature", request) + require.NoError(t, err) + + newRequest, err := storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + require.NoError(t, err) + require.Equal(t, request, newRequest) + + err = storage.DeletePKCERequestSession(ctx, "fancy-signature") + require.NoError(t, err) + + require.Equal(t, wantActions, client.Actions()) +} + +func TestGetNotFound(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + _, notFoundErr := storage.GetPKCERequestSession(ctx, "non-existent-signature", nil) + require.EqualError(t, notFoundErr, "not_found") + require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound)) +} + +func TestWrongVersion(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + } + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + + require.EqualError(t, err, "pkce request data has wrong version: pkce session for fancy-signature has version not-the-right-version instead of 1") +} + +func TestNilSessionRequest(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4", + ResourceVersion: "", + Labels: map[string]string{ + "storage.pinniped.dev": "pkce", + }, + }, + Data: map[string][]byte{ + "pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`), + "pinniped-storage-version": []byte("1"), + }, + Type: "storage.pinniped.dev/pkce", + } + + _, err := secrets.Create(ctx, secret, metav1.CreateOptions{}) + require.NoError(t, err) + + _, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil) + require.EqualError(t, err, "malformed pkce session for fancy-signature: pkce request data must be present") +} + +func TestCreateWithNilRequester(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", nil) + require.EqualError(t, err, "requester must be of type fosite.Request") +} + +func TestCreateWithWrongRequesterDataTypes(t *testing.T) { + ctx := context.Background() + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets(namespace) + storage := New(secrets) + + request := &fosite.Request{ + Session: nil, + Client: &fosite.DefaultOpenIDConnectClient{}, + } + err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's session must be of type openid.DefaultSession") + + request = &fosite.Request{ + Session: &openid.DefaultSession{}, + Client: nil, + } + err = storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request) + require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient") +} diff --git a/internal/mocks/mockupstreamoidcidentityprovider/generate.go b/internal/mocks/mockupstreamoidcidentityprovider/generate.go new file mode 100644 index 00000000..cb9c46df --- /dev/null +++ b/internal/mocks/mockupstreamoidcidentityprovider/generate.go @@ -0,0 +1,6 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package mockupstreamoidcidentityprovider + +//go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider UpstreamOIDCIdentityProviderI diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go new file mode 100644 index 00000000..93085f4b --- /dev/null +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -0,0 +1,159 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: go.pinniped.dev/internal/oidc/provider (interfaces: UpstreamOIDCIdentityProviderI) + +// Package mockupstreamoidcidentityprovider is a generated GoMock package. +package mockupstreamoidcidentityprovider + +import ( + context "context" + gomock "github.com/golang/mock/gomock" + nonce "go.pinniped.dev/pkg/oidcclient/nonce" + oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" + pkce "go.pinniped.dev/pkg/oidcclient/pkce" + oauth2 "golang.org/x/oauth2" + url "net/url" + reflect "reflect" +) + +// MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface +type MockUpstreamOIDCIdentityProviderI struct { + ctrl *gomock.Controller + recorder *MockUpstreamOIDCIdentityProviderIMockRecorder +} + +// MockUpstreamOIDCIdentityProviderIMockRecorder is the mock recorder for MockUpstreamOIDCIdentityProviderI +type MockUpstreamOIDCIdentityProviderIMockRecorder struct { + mock *MockUpstreamOIDCIdentityProviderI +} + +// NewMockUpstreamOIDCIdentityProviderI creates a new mock instance +func NewMockUpstreamOIDCIdentityProviderI(ctrl *gomock.Controller) *MockUpstreamOIDCIdentityProviderI { + mock := &MockUpstreamOIDCIdentityProviderI{ctrl: ctrl} + mock.recorder = &MockUpstreamOIDCIdentityProviderIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityProviderIMockRecorder { + return m.recorder +} + +// ExchangeAuthcodeAndValidateTokens mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(oidctypes.Token) + ret1, _ := ret[1].(map[string]interface{}) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4) +} + +// GetAuthorizationURL mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetAuthorizationURL() *url.URL { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAuthorizationURL") + ret0, _ := ret[0].(*url.URL) + return ret0 +} + +// GetAuthorizationURL indicates an expected call of GetAuthorizationURL +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetAuthorizationURL() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationURL", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetAuthorizationURL)) +} + +// GetClientID mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetClientID() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClientID") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetClientID indicates an expected call of GetClientID +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetClientID() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientID", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetClientID)) +} + +// GetGroupsClaim mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetGroupsClaim() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupsClaim") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetGroupsClaim indicates an expected call of GetGroupsClaim +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetGroupsClaim() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetGroupsClaim)) +} + +// GetName mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetName() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetName") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetName indicates an expected call of GetName +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetName() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetName)) +} + +// GetScopes mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetScopes() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetScopes") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetScopes indicates an expected call of GetScopes +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetScopes() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetScopes)) +} + +// GetUsernameClaim mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) GetUsernameClaim() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUsernameClaim") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetUsernameClaim indicates an expected call of GetUsernameClaim +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsernameClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetUsernameClaim)) +} + +// ValidateToken mocks base method +func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2) + ret0, _ := ret[0].(oidctypes.Token) + ret1, _ := ret[1].(map[string]interface{}) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// ValidateToken indicates an expected call of ValidateToken +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateToken), arg0, arg1, arg2) +} diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index ec752cf3..f3a305f1 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -9,14 +9,13 @@ import ( "net/http" "time" - "github.com/gorilla/securecookie" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" "golang.org/x/oauth2" "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/plog" @@ -24,42 +23,15 @@ import ( "go.pinniped.dev/pkg/oidcclient/pkce" ) -const ( - // Just in case we need to make a breaking change to the format of the upstream state param, - // we are including a format version number. This gives the opportunity for a future version of Pinniped - // to have the consumer of this format decide to reject versions that it doesn't understand. - upstreamStateParamFormatVersion = "1" - - // The `name` passed to the encoder for encoding the upstream state param value. This name is short - // because it will be encoded into the upstream state param value and we're trying to keep that small. - upstreamStateParamEncodingName = "s" - - // The name of the browser cookie which shall hold our CSRF value. - // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes - csrfCookieName = "__Host-pinniped-csrf" - - // The `name` passed to the encoder for encoding and decoding the CSRF cookie contents. - csrfCookieEncodingName = "csrf" -) - -type IDPListGetter interface { - GetIDPList() []provider.UpstreamOIDCIdentityProvider -} - -// This is the encoding side of the securecookie.Codec interface. -type Encoder interface { - Encode(name string, value interface{}) (string, error) -} - func NewHandler( - issuer string, - idpListGetter IDPListGetter, + downstreamIssuer string, + idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), generateNonce func() (nonce.Nonce, error), - upstreamStateEncoder Encoder, - cookieCodec securecookie.Codec, + upstreamStateEncoder oidc.Encoder, + cookieCodec oidc.Codec, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodPost && r.Method != http.MethodGet { @@ -69,11 +41,7 @@ func NewHandler( return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) } - csrfFromCookie, err := readCSRFCookie(r, cookieCodec) - if err != nil { - plog.InfoErr("error reading CSRF cookie", err) - return err - } + csrfFromCookie := readCSRFCookie(r, cookieCodec) authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) if err != nil { @@ -116,15 +84,22 @@ func NewHandler( } upstreamOAuthConfig := oauth2.Config{ - ClientID: upstreamIDP.ClientID, + ClientID: upstreamIDP.GetClientID(), Endpoint: oauth2.Endpoint{ - AuthURL: upstreamIDP.AuthorizationURL.String(), + AuthURL: upstreamIDP.GetAuthorizationURL().String(), }, - RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.Name), - Scopes: upstreamIDP.Scopes, + RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer), + Scopes: upstreamIDP.GetScopes(), } - encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) + encodedStateParamValue, err := upstreamStateParam( + authorizeRequester, + upstreamIDP.GetName(), + nonceValue, + csrfValue, + pkceValue, + upstreamStateEncoder, + ) if err != nil { plog.Error("authorize upstream state param error", err) return err @@ -154,20 +129,23 @@ func NewHandler( }) } -func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { - receivedCSRFCookie, err := r.Cookie(csrfCookieName) +func readCSRFCookie(r *http.Request, codec oidc.Codec) csrftoken.CSRFToken { + receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found - return "", nil + return "" } var csrfFromCookie csrftoken.CSRFToken - err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) if err != nil { - return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) + // We can ignore any errors and just make a new cookie. Hopefully this will + // make the user experience better if, for example, the server rotated + // cookie signing keys and then a user submitted a very old cookie. + return "" } - return csrfFromCookie, nil + return csrfFromCookie } func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { @@ -178,7 +156,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } -func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { +func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (provider.UpstreamOIDCIdentityProviderI, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { return nil, httperr.New( @@ -191,7 +169,7 @@ func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdent "Too many upstream providers are configured (support for multiple upstreams is not yet implemented)", ) } - return &allUpstreamIDPs[0], nil + return allUpstreamIDPs[0], nil } func generateValues( @@ -214,48 +192,42 @@ func generateValues( return csrfValue, nonceValue, pkceValue, nil } -// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param. -type upstreamStateParamData struct { - AuthParams string `json:"p"` - Nonce nonce.Nonce `json:"n"` - CSRFToken csrftoken.CSRFToken `json:"c"` - PKCECode pkce.Code `json:"k"` - StateParamFormatVersion string `json:"v"` -} - func upstreamStateParam( authorizeRequester fosite.AuthorizeRequester, + upstreamName string, nonceValue nonce.Nonce, csrfValue csrftoken.CSRFToken, pkceValue pkce.Code, - encoder Encoder, + encoder oidc.Encoder, ) (string, error) { - stateParamData := upstreamStateParamData{ - AuthParams: authorizeRequester.GetRequestForm().Encode(), - Nonce: nonceValue, - CSRFToken: csrfValue, - PKCECode: pkceValue, - StateParamFormatVersion: upstreamStateParamFormatVersion, + stateParamData := oidc.UpstreamStateParamData{ + AuthParams: authorizeRequester.GetRequestForm().Encode(), + UpstreamName: upstreamName, + Nonce: nonceValue, + CSRFToken: csrfValue, + PKCECode: pkceValue, + FormatVersion: oidc.UpstreamStateParamFormatVersion, } - encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData) + encodedStateParamValue, err := encoder.Encode(oidc.UpstreamStateParamEncodingName, stateParamData) if err != nil { return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err) } return encodedStateParamValue, nil } -func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { - encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue) +func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec oidc.Codec) error { + encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue) if err != nil { return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) } http.SetCookie(w, &http.Cookie{ - Name: csrfCookieName, + Name: oidc.CSRFCookieName, Value: encodedCSRFValue, HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: true, + Path: "/", }) return nil diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 3a522495..73c05248 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -21,6 +21,7 @@ import ( "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/pkce" @@ -28,6 +29,7 @@ import ( func TestAuthorizationEndpoint(t *testing.T) { const ( + downstreamIssuer = "https://my-downstream-issuer.com/some-path" downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" ) @@ -113,21 +115,19 @@ func TestAuthorizationEndpoint(t *testing.T) { upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) - upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "some-idp", ClientID: "some-client-id", AuthorizationURL: *upstreamAuthURL, Scopes: []string{"scope1", "scope2"}, } - issuer := "https://my-issuer.com/some-path" - - // Configure fosite the same way that the production code would, except use in-memory storage. + // Configure fosite the same way that the production code would, using NullStorage to turn off storage. oauthStore := oidc.NullStorage{} hmacSecret := []byte("some secret - must have at least 32 bytes") var signingKeyIsUnused *ecdsa.PrivateKey require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") - oauthHelper := oidc.FositeOauth2Helper(issuer, oauthStore, hmacSecret, signingKeyIsUnused) + oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, signingKeyIsUnused) happyCSRF := "test-csrf" happyPKCE := "test-pkce" @@ -206,14 +206,19 @@ func TestAuthorizationEndpoint(t *testing.T) { return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) } - expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string { + expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride, upstreamNameOverride string) string { csrf := happyCSRF if csrfValueOverride != "" { csrf = csrfValueOverride } + upstreamName := upstreamOIDCIdentityProvider.Name + if upstreamNameOverride != "" { + upstreamName = upstreamNameOverride + } encoded, err := happyStateEncoder.Encode("s", - expectedUpstreamStateParamFormat{ + oidctestutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), + U: upstreamName, N: happyNonce, C: csrf, K: happyPKCE, @@ -234,7 +239,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": issuer + "/callback/some-idp", + "redirect_uri": downstreamIssuer + "/callback", }) } @@ -250,8 +255,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) - stateEncoder securecookie.Codec - cookieEncoder securecookie.Codec + stateEncoder oidc.Codec + cookieEncoder oidc.Codec method string path string contentType string @@ -271,8 +276,8 @@ func TestAuthorizationEndpoint(t *testing.T) { tests := []testCase{ { name: "happy path using GET without a CSRF cookie", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -283,14 +288,14 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, { name: "happy path using GET with a CSRF cookie", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -298,17 +303,17 @@ func TestAuthorizationEndpoint(t *testing.T) { cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, - csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, + csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + " ", wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, { name: "happy path using POST", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -322,13 +327,33 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "", wantBodyString: "", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, }, + { + name: "error while decoding CSRF cookie just generates a new cookie and succeeds as usual", + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, + method: http.MethodGet, + path: happyGetRequestPath, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + // Generated a new CSRF cookie and set it in the response. + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), + wantUpstreamStateParamInLocationHeader: true, + wantBodyStringWithLocationInHref: true, + }, { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -343,14 +368,14 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client - }, "")), + }, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, { name: "downstream redirect uri does not match what is configured for client", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -366,8 +391,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream client does not exist", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -381,8 +406,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "response type is unsupported", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -397,8 +422,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream scopes do not match what is configured for client", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -413,8 +438,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing response type in request", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -429,8 +454,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing client id in request", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -444,8 +469,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -460,8 +485,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -476,8 +501,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -492,8 +517,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -510,8 +535,8 @@ func TestAuthorizationEndpoint(t *testing.T) { // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -526,8 +551,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC validations are skipped when the openid scope was not requested", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -540,15 +565,15 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( - map[string]string{"prompt": "none login", "scope": "email"}, "", + map[string]string{"prompt": "none login", "scope": "email"}, "", "", )), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, { name: "state does not have enough entropy", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -563,8 +588,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding upstream state param", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -578,8 +603,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding CSRF cookie value for new cookie", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -593,8 +618,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating CSRF token", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -608,8 +633,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating nonce", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -623,8 +648,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating PKCE", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -636,26 +661,10 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/plain; charset=utf-8", wantBodyString: "Internal Server Error: error generating PKCE param\n", }, - { - name: "error while decoding CSRF cookie", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), - generateCSRF: happyCSRFGenerator, - generatePKCE: happyPKCEGenerator, - generateNonce: happyNonceGenerator, - stateEncoder: happyStateEncoder, - cookieEncoder: happyCookieEncoder, - method: http.MethodGet, - path: happyGetRequestPath, - csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", - wantStatus: http.StatusUnprocessableEntity, - wantContentType: "text/plain; charset=utf-8", - wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n", - }, { name: "no upstream providers are configured", - issuer: issuer, - idpListGetter: newIDPListGetter(), // empty + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(), // empty method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -664,8 +673,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "too many upstream providers are configured", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -674,8 +683,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PUT is a bad method", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -684,8 +693,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PATCH is a bad method", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -694,8 +703,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "DELETE is a bad method", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -712,6 +721,8 @@ func TestAuthorizationEndpoint(t *testing.T) { } rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) + t.Logf("response: %#v", rsp) + t.Logf("response body: %q", rsp.Body.String()) require.Equal(t, test.wantStatus, rsp.Code) requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) @@ -742,7 +753,7 @@ func TestAuthorizationEndpoint(t *testing.T) { if test.wantCSRFValueInCookieHeader != "" { require.Len(t, rsp.Header().Values("Set-Cookie"), 1) actualCookie := rsp.Header().Get("Set-Cookie") - regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict") + regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); Path=/; HttpOnly; Secure; SameSite=Strict") submatches := regex.FindStringSubmatch(actualCookie) require.Len(t, submatches, 2) captured := submatches[1] @@ -772,13 +783,13 @@ func TestAuthorizationEndpoint(t *testing.T) { runOneTestCase(t, test, subject) // Call the setter to change the upstream IDP settings. - newProviderSettings := provider.UpstreamOIDCIdentityProvider{ + newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{ Name: "some-other-idp", ClientID: "some-other-client-id", AuthorizationURL: *upstreamAuthURL, Scopes: []string{"other-scope1", "other-scope2"}, } - test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{newProviderSettings}) + test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)}) // Update the expectations of the test case to match the new upstream IDP settings. test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), @@ -787,11 +798,11 @@ func TestAuthorizationEndpoint(t *testing.T) { "access_type": "offline", "scope": "other-scope1 other-scope2", "client_id": "some-other-client-id", - "state": expectedUpstreamStateParam(nil, ""), + "state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": issuer + "/callback/some-other-idp", + "redirect_uri": downstreamIssuer + "/callback", }, ) test.wantBodyString = fmt.Sprintf(`Found.%s`, @@ -807,20 +818,8 @@ func TestAuthorizationEndpoint(t *testing.T) { }) } -// Declare a separate type from the production code to ensure that the state param's contents was serialized -// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of -// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality -// assertions about the redirect URL in this test. -type expectedUpstreamStateParamFormat struct { - P string `json:"p"` - N string `json:"n"` - C string `json:"c"` - K string `json:"k"` - V string `json:"v"` -} - type errorReturningEncoder struct { - securecookie.Codec + oidc.Codec } func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { @@ -843,7 +842,7 @@ func requireEqualContentType(t *testing.T, actual string, expected string) { require.Equal(t, actualContentTypeParams, expectedContentTypeParams) } -func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) { +func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) { t.Helper() actualLocationURL, err := url.Parse(actualURL) require.NoError(t, err) @@ -852,13 +851,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL expectedQueryStateParam := expectedLocationURL.Query().Get("state") require.NotEmpty(t, expectedQueryStateParam) - var expectedDecodedStateParam expectedUpstreamStateParamFormat + var expectedDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) require.NoError(t, err) actualQueryStateParam := actualLocationURL.Query().Get("state") require.NotEmpty(t, actualQueryStateParam) - var actualDecodedStateParam expectedUpstreamStateParamFormat + var actualDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) require.NoError(t, err) @@ -871,10 +870,20 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore require.NoError(t, err) expectedLocationURL, err := url.Parse(expectedURL) require.NoError(t, err) - require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme) - require.Equal(t, expectedLocationURL.User, actualLocationURL.User) - require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host) - require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path) + require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme, + "schemes were not equal: expected %s but got %s", expectedURL, actualURL, + ) + require.Equal(t, expectedLocationURL.User, actualLocationURL.User, + "users were not equal: expected %s but got %s", expectedURL, actualURL, + ) + + require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host, + "hosts were not equal: expected %s but got %s", expectedURL, actualURL, + ) + + require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path, + "paths were not equal: expected %s but got %s", expectedURL, actualURL, + ) expectedLocationQuery := expectedLocationURL.Query() actualLocationQuery := actualLocationURL.Query() @@ -886,9 +895,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore } require.Equal(t, expectedLocationQuery, actualLocationQuery) } - -func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { - idpProvider := provider.NewDynamicUpstreamIDPProvider() - idpProvider.SetIDPList(upstreamOIDCIdentityProviders) - return idpProvider -} diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go new file mode 100644 index 00000000..4add765e --- /dev/null +++ b/internal/oidc/callback/callback_handler.go @@ -0,0 +1,306 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package callback provides a handler for the OIDC callback endpoint. +package callback + +import ( + "crypto/subtle" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/token/jwt" + + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/csrftoken" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/plog" +) + +const ( + // The name of the issuer claim specified in the OIDC spec. + idTokenIssuerClaim = "iss" + + // The name of the subject claim specified in the OIDC spec. + idTokenSubjectClaim = "sub" + + // defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC + // ID token if the upstream OIDC IDP did not tell us to use another claim. + defaultUpstreamUsernameClaim = idTokenSubjectClaim + + // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token + // information. + downstreamGroupsClaim = "groups" +) + +func NewHandler( + idpListGetter oidc.IDPListGetter, + oauthHelper fosite.OAuth2Provider, + stateDecoder, cookieDecoder oidc.Decoder, + redirectURI string, +) http.Handler { + return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + state, err := validateRequest(r, stateDecoder, cookieDecoder) + if err != nil { + return err + } + + upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter) + if upstreamIDPConfig == nil { + plog.Warning("upstream provider not found") + return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") + } + + downstreamAuthParams, err := url.ParseQuery(state.AuthParams) + if err != nil { + plog.Error("error reading state downstream auth params", err) + return httperr.New(http.StatusBadRequest, "error reading state downstream auth params") + } + + // Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest(). + reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams} + authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest) + if err != nil { + plog.Error("error using state downstream auth params", err) + return httperr.New(http.StatusBadRequest, "error using state downstream auth params") + } + + // Grant the openid scope only if it was requested. + grantOpenIDScopeIfRequested(authorizeRequester) + + _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( + r.Context(), + authcode(r), + state.PKCECode, + state.Nonce, + redirectURI, + ) + if err != nil { + plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) + return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") + } + + username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + if err != nil { + return err + } + + groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + if err != nil { + return err + } + + openIDSession := makeDownstreamSession(username, groups) + authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession) + if err != nil { + plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName()) + return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err) + } + + oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) + + return nil + }) +} + +func authcode(r *http.Request) string { + return r.FormValue("code") +} + +func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { + if r.Method != http.MethodGet { + return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + } + + csrfValue, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + plog.InfoErr("error reading CSRF cookie", err) + return nil, err + } + + if authcode(r) == "" { + plog.Info("code param not found") + return nil, httperr.New(http.StatusBadRequest, "code param not found") + } + + if r.FormValue("state") == "" { + plog.Info("state param not found") + return nil, httperr.New(http.StatusBadRequest, "state param not found") + } + + state, err := readState(r, stateDecoder) + if err != nil { + plog.InfoErr("error reading state", err) + return nil, err + } + + if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 { + plog.InfoErr("CSRF value does not match", err) + return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) + } + + return state, nil +} + +func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { + for _, p := range idpListGetter.GetIDPList() { + if p.GetName() == upstreamName { + return p + } + } + return nil +} + +func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) { + receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) + if err != nil { + // Error means that the cookie was not found + return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) + } + + var csrfFromCookie csrftoken.CSRFToken + err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + if err != nil { + return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) + } + + return csrfFromCookie, nil +} + +func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { + var state oidc.UpstreamStateParamData + if err := stateDecoder.Decode( + oidc.UpstreamStateParamEncodingName, + r.FormValue("state"), + &state, + ); err != nil { + return nil, httperr.New(http.StatusBadRequest, "error reading state") + } + + if state.FormatVersion != oidc.UpstreamStateParamFormatVersion { + return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") + } + + return &state, nil +} + +func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { + for _, scope := range authorizeRequester.GetRequestedScopes() { + if scope == "openid" { + authorizeRequester.GrantScope(scope) + } + } +} + +func getUsernameFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) (string, error) { + usernameClaim := upstreamIDPConfig.GetUsernameClaim() + + user := "" + if usernameClaim == "" { + // The spec says the "sub" claim is only unique per issuer, so by default when there is + // no specific username claim configured we will prepend the issuer string to make it globally unique. + upstreamIssuer := idTokenClaims[idTokenIssuerClaim] + if upstreamIssuer == "" { + plog.Warning( + "issuer claim in upstream ID token missing", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing") + } + upstreamIssuerAsString, ok := upstreamIssuer.(string) + if !ok { + plog.Warning( + "issuer claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "issClaim", upstreamIssuer, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") + } + user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim) + usernameClaim = defaultUpstreamUsernameClaim + } + + usernameAsInterface, ok := idTokenClaims[usernameClaim] + if !ok { + plog.Warning( + "no username claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), + "usernameClaim", usernameClaim, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token") + } + + username, ok := usernameAsInterface.(string) + if !ok { + plog.Warning( + "username claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(), + "usernameClaim", usernameClaim, + ) + return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format") + } + + return fmt.Sprintf("%s%s", user, username), nil +} + +func getGroupsFromUpstreamIDToken( + upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, + idTokenClaims map[string]interface{}, +) ([]string, error) { + groupsClaim := upstreamIDPConfig.GetGroupsClaim() + if groupsClaim == "" { + return nil, nil + } + + groupsAsInterface, ok := idTokenClaims[groupsClaim] + if !ok { + plog.Warning( + "no groups claim in upstream ID token", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), + "groupsClaim", groupsClaim, + ) + return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token") + } + + groups, ok := groupsAsInterface.([]string) + if !ok { + plog.Warning( + "groups claim in upstream ID token has invalid format", + "upstreamName", upstreamIDPConfig.GetName(), + "configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(), + "groupsClaim", groupsClaim, + ) + return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format") + } + + return groups, nil +} + +func makeDownstreamSession(username string, groups []string) *openid.DefaultSession { + now := time.Now().UTC() + openIDSession := &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: username, + RequestedAt: now, + AuthTime: now, + }, + } + if groups != nil { + openIDSession.Claims.Extra = map[string]interface{}{ + downstreamGroupsClaim: groups, + } + } + return openIDSession +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go new file mode 100644 index 00000000..ead11693 --- /dev/null +++ b/internal/oidc/callback/callback_handler_test.go @@ -0,0 +1,859 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package callback + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + + "github.com/gorilla/securecookie" + "github.com/ory/fosite" + "github.com/ory/fosite/handler/openid" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/kubernetes/fake" + kubetesting "k8s.io/client-go/testing" + + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/oidctestutil" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" +) + +const ( + happyUpstreamIDPName = "upstream-idp-name" + + upstreamIssuer = "https://my-upstream-issuer.com" + upstreamSubject = "abc123-some-guid" + upstreamUsername = "test-pinniped-username" + + upstreamUsernameClaim = "the-user-claim" + upstreamGroupsClaim = "the-groups-claim" + + happyUpstreamAuthcode = "upstream-auth-code" + + happyUpstreamRedirectURI = "https://example.com/callback" + + happyDownstreamState = "some-downstream-state-with-at-least-32-bytes" + happyDownstreamCSRF = "test-csrf" + happyDownstreamPKCE = "test-pkce" + happyDownstreamNonce = "test-nonce" + happyDownstreamStateVersion = "1" + + downstreamIssuer = "https://my-downstream-issuer.com/path" + downstreamRedirectURI = "http://127.0.0.1/callback" + downstreamClientID = "pinniped-cli" + downstreamNonce = "some-nonce-value" + downstreamPKCEChallenge = "some-challenge" + downstreamPKCEChallengeMethod = "S256" + + timeComparisonFudgeFactor = time.Second * 15 +) + +var ( + upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} + happyDownstreamScopesRequested = []string{"openid", "profile", "email"} + + happyDownstreamRequestParamsQuery = url.Values{ + "response_type": []string{"code"}, + "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, + "client_id": []string{downstreamClientID}, + "state": []string{happyDownstreamState}, + "nonce": []string{downstreamNonce}, + "code_challenge": []string{downstreamPKCEChallenge}, + "code_challenge_method": []string{downstreamPKCEChallengeMethod}, + "redirect_uri": []string{downstreamRedirectURI}, + } + happyDownstreamRequestParams = happyDownstreamRequestParamsQuery.Encode() +) + +func TestCallbackEndpoint(t *testing.T) { + otherUpstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: "other-upstream-idp-name", + ClientID: "other-some-client-id", + Scopes: []string{"other-scope1", "other-scope2"}, + } + + var stateEncoderHashKey = []byte("fake-hash-secret") + var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + var cookieEncoderHashKey = []byte("fake-hash-secret2") + var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + + happyState := happyUpstreamStateParam().Build(t, happyStateCodec) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) + require.NoError(t, err) + happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + + happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{ + Authcode: happyUpstreamAuthcode, + PKCECodeVerifier: pkce.Code(happyDownstreamPKCE), + ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), + RedirectURI: happyUpstreamRedirectURI, + } + + // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it + happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState + + tests := []struct { + name string + + idp oidctestutil.TestUpstreamOIDCIdentityProvider + method string + path string + csrfCookie string + + wantStatus int + wantBody string + wantRedirectLocationRegexp string + wantGrantedOpenidScope bool + wantDownstreamIDTokenSubject string + wantDownstreamIDTokenGroups []string + wantDownstreamRequestedScopes []string + wantDownstreamNonce string + wantDownstreamPKCEChallenge string + wantDownstreamPKCEChallengeMethod string + + wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs + }{ + { + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, + wantGrantedOpenidScope: true, + wantBody: "", + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups", + idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, + wantGrantedOpenidScope: true, + wantBody: "", + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject, + wantDownstreamIDTokenGroups: nil, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for", + idp: happyUpstream().WithUsernameClaim("sub").Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, + wantGrantedOpenidScope: true, + wantBody: "", + wantDownstreamIDTokenSubject: upstreamSubject, + wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + + // Pre-upstream-exchange verification + { + name: "PUT method is invalid", + method: http.MethodPut, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: PUT (try GET)\n", + }, + { + name: "POST method is invalid", + method: http.MethodPost, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: POST (try GET)\n", + }, + { + name: "PATCH method is invalid", + method: http.MethodPatch, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: PATCH (try GET)\n", + }, + { + name: "DELETE method is invalid", + method: http.MethodDelete, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: DELETE (try GET)\n", + }, + { + name: "code param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithoutCode().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", + }, + { + name: "state param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state\n", + }, + { + // This shouldn't happen in practice because the authorize endpoint should have already run the same + // validations, but we would like to test the error handling in this endpoint anyway. + name: "state param contains authorization request params which fail validation", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()). + Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantStatus: http.StatusInternalServerError, + wantBody: "Internal Server Error: error while generating and saving authcode\n", + }, + { + name: "state's internal version does not match what we want", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: state format version is invalid\n", + }, + { + name: "state's downstream auth params element is invalid", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam(). + WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z"). + Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state downstream auth params\n", + }, + { + name: "state's downstream auth params are missing required value (e.g., client_id)", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": ""}).Encode()). + Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error using state downstream auth params\n", + }, + { + name: "state's downstream auth params does not contain openid scope", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath(). + WithState( + happyUpstreamStateParam(). + WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()). + Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + wantDownstreamIDTokenSubject: upstreamUsername, + wantDownstreamRequestedScopes: []string{"profile", "email"}, + wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "the UpstreamOIDCProvider CRD has been deleted", + idp: otherUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: upstream provider not found\n", + }, + { + name: "the CSRF cookie does not exist on request", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF cookie is missing\n", + }, + { + name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "cookie csrf value does not match state csrf value", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF value does not match\n", + }, + + // Upstream exchange + { + name: "upstream auth code exchange fails", + idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token does not contain requested username claim", + idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token does not contain requested groups claim", + idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token contains username claim with weird format", + idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token does not contain iss claim when using default username claim config", + idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token has an non-string iss claim when using default username claim config", + idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream ID token contains groups claim with weird format", + idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + } + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + client := fake.NewSimpleClientset() + secrets := client.CoreV1().Secrets("some-namespace") + + // Configure fosite the same way that the production code would. + // Inject this into our test subject at the last second so we get a fresh storage for every test. + oauthStore := oidc.NewKubeStorage(secrets) + hmacSecret := []byte("some secret - must have at least 32 bytes") + require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") + oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) + + idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) + req := httptest.NewRequest(test.method, test.path, nil) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } + rsp := httptest.NewRecorder() + subject.ServeHTTP(rsp, req) + t.Logf("response: %#v", rsp) + t.Logf("response body: %q", rsp.Body.String()) + + if test.wantExchangeAndValidateTokensCall != nil { + require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + test.wantExchangeAndValidateTokensCall.Ctx = req.Context() + require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) + } else { + require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + } + + require.Equal(t, test.wantStatus, rsp.Code) + + if test.wantBody != "" { + require.Equal(t, test.wantBody, rsp.Body.String()) + } else { + require.Empty(t, rsp.Body.String()) + } + + if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test + // Assert that Location header matches regular expression. + require.Len(t, rsp.Header().Values("Location"), 1) + actualLocation := rsp.Header().Get("Location") + regex := regexp.MustCompile(test.wantRedirectLocationRegexp) + submatches := regex.FindStringSubmatch(actualLocation) + require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation) + capturedAuthCode := submatches[1] + + // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface + authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") + require.Len(t, authcodeDataAndSignature, 2) + + // Several Secrets should have been created + expectedNumberOfCreatedSecrets := 2 + if test.wantGrantedOpenidScope { + expectedNumberOfCreatedSecrets++ + } + require.Len(t, client.Actions(), expectedNumberOfCreatedSecrets) + + actualSecretNames := []string{} + for i := range client.Actions() { + actualAction := client.Actions()[i].(kubetesting.CreateActionImpl) + require.Equal(t, "create", actualAction.GetVerb()) + require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource()) + actualSecret := actualAction.GetObject().(*corev1.Secret) + require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace + actualSecretNames = append(actualSecretNames, actualSecret.Name) + } + + // One authcode should have been stored. + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-authcode-") + + storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage( + t, + oauthStore, + authcodeDataAndSignature[1], // Authcode store key is authcode signature + test.wantGrantedOpenidScope, + test.wantDownstreamIDTokenSubject, + test.wantDownstreamIDTokenGroups, + test.wantDownstreamRequestedScopes, + ) + + // One PKCE should have been stored. + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-pkce-") + + validatePKCEStorage( + t, + oauthStore, + authcodeDataAndSignature[1], // PKCE store key is authcode signature + storedRequestFromAuthcode, + storedSessionFromAuthcode, + test.wantDownstreamPKCEChallenge, + test.wantDownstreamPKCEChallengeMethod, + ) + + // One IDSession should have been stored, if the downstream actually requested the "openid" scope + if test.wantGrantedOpenidScope { + requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-oidc") + + validateIDSessionStorage( + t, + oauthStore, + capturedAuthCode, // IDSession store key is full authcode + storedRequestFromAuthcode, + storedSessionFromAuthcode, + test.wantDownstreamNonce, + ) + } + } + }) + } +} + +type requestPath struct { + code, state *string +} + +func newRequestPath() *requestPath { + c := happyUpstreamAuthcode + s := "4321" + return &requestPath{ + code: &c, + state: &s, + } +} + +func (r *requestPath) WithCode(code string) *requestPath { + r.code = &code + return r +} + +func (r *requestPath) WithoutCode() *requestPath { + r.code = nil + return r +} + +func (r *requestPath) WithState(state string) *requestPath { + r.state = &state + return r +} + +func (r *requestPath) WithoutState() *requestPath { + r.state = nil + return r +} + +func (r *requestPath) String() string { + path := "/downstream-provider-name/callback?" + params := url.Values{} + if r.code != nil { + params.Add("code", *r.code) + } + if r.state != nil { + params.Add("state", *r.state) + } + return path + params.Encode() +} + +type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat + +func happyUpstreamStateParam() *upstreamStateParamBuilder { + return &upstreamStateParamBuilder{ + U: happyUpstreamIDPName, + P: happyDownstreamRequestParams, + N: happyDownstreamNonce, + C: happyDownstreamCSRF, + K: happyDownstreamPKCE, + V: happyDownstreamStateVersion, + } +} + +func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { + state, err := stateEncoder.Encode("s", b) + require.NoError(t, err) + return state +} + +func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder { + b.P = params + return b +} + +func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder { + b.N = nonce + return b +} + +func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder { + b.C = csrf + return b +} + +func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder { + b.K = pkce + return b +} + +func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder { + b.V = version + return b +} + +type upstreamOIDCIdentityProviderBuilder struct { + idToken map[string]interface{} + usernameClaim, groupsClaim string + authcodeExchangeErr error +} + +func happyUpstream() *upstreamOIDCIdentityProviderBuilder { + return &upstreamOIDCIdentityProviderBuilder{ + usernameClaim: upstreamUsernameClaim, + groupsClaim: upstreamGroupsClaim, + idToken: map[string]interface{}{ + "iss": upstreamIssuer, + "sub": upstreamSubject, + upstreamUsernameClaim: upstreamUsername, + upstreamGroupsClaim: upstreamGroupMembership, + "other-claim": "should be ignored", + }, + } +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder { + u.usernameClaim = claim + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder { + u.usernameClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder { + u.groupsClaim = "" + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *upstreamOIDCIdentityProviderBuilder { + u.idToken[name] = value + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder { + delete(u.idToken, claim) + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder { + u.authcodeExchangeErr = err + return u +} + +func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider { + return oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + UsernameClaim: u.usernameClaim, + GroupsClaim: u.groupsClaim, + Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr + }, + } +} + +func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { + copied := url.Values{} + for key, value := range query { + copied[key] = value + } + for key, value := range modifications { + if value == "" { + copied.Del(key) + } else { + copied[key] = []string{value} + } + } + return copied +} + +func validateAuthcodeStorage( + t *testing.T, + oauthStore *oidc.KubeStorage, + storeKey string, + wantGrantedOpenidScope bool, + wantDownstreamIDTokenSubject string, + wantDownstreamIDTokenGroups []string, + wantDownstreamRequestedScopes []string, +) (*fosite.Request, *openid.DefaultSession) { + t.Helper() + + // Get the authcode session back from storage so we can require that it was stored correctly. + storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromAuthcode, storedSessionFromAuthcode := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromAuthcode) + + // Check which scopes were granted. + if wantGrantedOpenidScope { + require.Contains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid") + } else { + require.NotContains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid") + } + + // Check all the other fields of the stored request. + require.NotEmpty(t, storedRequestFromAuthcode.ID) + require.Equal(t, downstreamClientID, storedRequestFromAuthcode.Client.GetID()) + require.ElementsMatch(t, wantDownstreamRequestedScopes, storedRequestFromAuthcode.RequestedScope) + require.Nil(t, storedRequestFromAuthcode.RequestedAudience) + require.Empty(t, storedRequestFromAuthcode.GrantedAudience) + require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequestFromAuthcode.Form) + testutil.RequireTimeInDelta(t, time.Now(), storedRequestFromAuthcode.RequestedAt, timeComparisonFudgeFactor) + + // We're not using these fields yet, so confirm that we did not set them (for now). + require.Empty(t, storedSessionFromAuthcode.Subject) + require.Empty(t, storedSessionFromAuthcode.Username) + require.Empty(t, storedSessionFromAuthcode.Headers) + + // The authcode that we are issuing should be good for the length of time that we declare in the fosite config. + testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*3), storedSessionFromAuthcode.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor) + require.Len(t, storedSessionFromAuthcode.ExpiresAt, 1) + + // Now confirm the ID token claims. + actualClaims := storedSessionFromAuthcode.Claims + + // Check the user's identity, which are put into the downstream ID token's subject and groups claims. + require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) + if wantDownstreamIDTokenGroups != nil { + require.Len(t, actualClaims.Extra, 1) + require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"]) + } else { + require.Empty(t, actualClaims.Extra) + require.NotContains(t, actualClaims.Extra, "groups") + } + + // Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time). + testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor) + testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.AuthTime, timeComparisonFudgeFactor) + requestedAtZone, _ := actualClaims.RequestedAt.Zone() + require.Equal(t, "UTC", requestedAtZone) + authTimeZone, _ := actualClaims.AuthTime.Zone() + require.Equal(t, "UTC", authTimeZone) + + // Fosite will set these fields for us in the token endpoint based on the store session + // information. Therefore, we assert that they are empty because we want the library to do the + // lifting for us. + require.Empty(t, actualClaims.Issuer) + require.Nil(t, actualClaims.Audience) + require.Empty(t, actualClaims.Nonce) + require.Zero(t, actualClaims.ExpiresAt) + require.Zero(t, actualClaims.IssuedAt) + + // These are not needed yet. + require.Empty(t, actualClaims.JTI) + require.Empty(t, actualClaims.CodeHash) + require.Empty(t, actualClaims.AccessTokenHash) + require.Empty(t, actualClaims.AuthenticationContextClassReference) + require.Empty(t, actualClaims.AuthenticationMethodsReference) + + return storedRequestFromAuthcode, storedSessionFromAuthcode +} + +func validatePKCEStorage( + t *testing.T, + oauthStore *oidc.KubeStorage, + storeKey string, + storedRequestFromAuthcode *fosite.Request, + storedSessionFromAuthcode *openid.DefaultSession, + wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod string, +) { + t.Helper() + + storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromPKCE, storedSessionFromPKCE := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromPKCE) + + // The stored PKCE request should be the same as the stored authcode request. + require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromPKCE.ID) + require.Equal(t, storedSessionFromAuthcode, storedSessionFromPKCE) + + // The stored PKCE request should also contain the PKCE challenge that the downstream sent us. + require.Equal(t, wantDownstreamPKCEChallenge, storedRequestFromPKCE.Form.Get("code_challenge")) + require.Equal(t, wantDownstreamPKCEChallengeMethod, storedRequestFromPKCE.Form.Get("code_challenge_method")) +} + +func validateIDSessionStorage( + t *testing.T, + oauthStore *oidc.KubeStorage, + storeKey string, + storedRequestFromAuthcode *fosite.Request, + storedSessionFromAuthcode *openid.DefaultSession, + wantDownstreamNonce string, +) { + t.Helper() + + storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil) + require.NoError(t, err) + + // Check that storage returned the expected concrete data types. + storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession) + + // The stored IDSession request should be the same as the stored authcode request. + require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID) + require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession) + + // The stored IDSession request should also contain the nonce that the downstream sent us. + require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce")) +} + +func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *openid.DefaultSession) { + t.Helper() + + storedRequest, ok := storedAuthorizeRequest.(*fosite.Request) + require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest, &fosite.Request{}) + storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession) + require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest.GetSession(), &openid.DefaultSession{}) + + return storedRequest, storedSession +} + +func requireAnyStringHasPrefix(t *testing.T, stringList []string, prefix string) { + t.Helper() + + containsPrefix := false + for i := range stringList { + if strings.HasPrefix(stringList[i], prefix) { + containsPrefix = true + } + } + require.Truef(t, containsPrefix, "list %v did not contain any strings with prefix %s", stringList, prefix) +} diff --git a/internal/oidc/kube_storage.go b/internal/oidc/kube_storage.go new file mode 100644 index 00000000..405a6ade --- /dev/null +++ b/internal/oidc/kube_storage.go @@ -0,0 +1,120 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package oidc + +import ( + "context" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/handler/oauth2" + "github.com/ory/fosite/handler/openid" + fositepkce "github.com/ory/fosite/handler/pkce" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" + + "go.pinniped.dev/internal/constable" + "go.pinniped.dev/internal/fositestorage/authorizationcode" + "go.pinniped.dev/internal/fositestorage/openidconnect" + "go.pinniped.dev/internal/fositestorage/pkce" +) + +const errKubeStorageNotImplemented = constable.Error("KubeStorage does not implement this method. It should not have been called.") + +type KubeStorage struct { + authorizationCodeStorage oauth2.AuthorizeCodeStorage + pkceStorage fositepkce.PKCERequestStorage + oidcStorage openid.OpenIDConnectRequestStorage +} + +func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage { + return &KubeStorage{ + authorizationCodeStorage: authorizationcode.New(secrets), + pkceStorage: pkce.New(secrets), + oidcStorage: openidconnect.New(secrets), + } +} + +func (KubeStorage) RevokeRefreshToken(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) RevokeAccessToken(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { + return nil +} + +func (KubeStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) { + return errKubeStorageNotImplemented +} + +func (KubeStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { + return nil +} + +func (KubeStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { + return nil, errKubeStorageNotImplemented +} + +func (KubeStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) { + return errKubeStorageNotImplemented +} + +func (k KubeStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error { + return k.oidcStorage.CreateOpenIDConnectSession(ctx, authcode, requester) +} + +func (k KubeStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) (fosite.Requester, error) { + return k.oidcStorage.GetOpenIDConnectSession(ctx, authcode, requester) +} + +func (k KubeStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error { + return k.oidcStorage.DeleteOpenIDConnectSession(ctx, authcode) +} + +func (k KubeStorage) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) { + return k.pkceStorage.GetPKCERequestSession(ctx, signature, session) +} + +func (k KubeStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error { + return k.pkceStorage.CreatePKCERequestSession(ctx, signature, requester) +} + +func (k KubeStorage) DeletePKCERequestSession(ctx context.Context, signature string) error { + return k.pkceStorage.DeletePKCERequestSession(ctx, signature) +} + +func (k KubeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, r fosite.Requester) (err error) { + return k.authorizationCodeStorage.CreateAuthorizeCodeSession(ctx, signature, r) +} + +func (k KubeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, s fosite.Session) (request fosite.Requester, err error) { + return k.authorizationCodeStorage.GetAuthorizeCodeSession(ctx, signature, s) +} + +func (k KubeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) { + return k.authorizationCodeStorage.InvalidateAuthorizeCodeSession(ctx, signature) +} + +func (KubeStorage) GetClient(_ context.Context, id string) (fosite.Client, error) { + client := PinnipedCLIOIDCClient() + if client.ID == id { + return client, nil + } + return nil, fosite.ErrNotFound +} + +func (KubeStorage) ClientAssertionJWTValid(_ context.Context, _ string) error { + return errKubeStorageNotImplemented +} + +func (KubeStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error { + return errKubeStorageNotImplemented +} diff --git a/internal/oidc/nullstorage.go b/internal/oidc/nullstorage.go index 3767f889..3dcd7a06 100644 --- a/internal/oidc/nullstorage.go +++ b/internal/oidc/nullstorage.go @@ -12,16 +12,16 @@ import ( "go.pinniped.dev/internal/constable" ) -const errNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.") +const errNullStorageNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.") type NullStorage struct{} func (NullStorage) RevokeRefreshToken(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) RevokeAccessToken(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -29,11 +29,11 @@ func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosi } func (NullStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -41,11 +41,11 @@ func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosit } func (NullStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error { @@ -53,15 +53,15 @@ func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fos } func (NullStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error { @@ -69,7 +69,7 @@ func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosit } func (NullStorage) DeletePKCERequestSession(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Requester) (err error) { @@ -77,11 +77,11 @@ func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fos } func (NullStorage) GetAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { - return nil, errNotImplemented + return nil, errNullStorageNotImplemented } func (NullStorage) InvalidateAuthorizeCodeSession(_ context.Context, _ string) (err error) { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error) { @@ -93,9 +93,9 @@ func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error } func (NullStorage) ClientAssertionJWTValid(_ context.Context, _ string) error { - return errNotImplemented + return errNullStorageNotImplemented } func (NullStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error { - return errNotImplemented + return errNullStorageNotImplemented } diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 508ae334..14f9a725 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -10,15 +10,72 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + + "go.pinniped.dev/internal/oidc/csrftoken" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/pkce" ) const ( WellKnownEndpointPath = "/.well-known/openid-configuration" AuthorizationEndpointPath = "/oauth2/authorize" TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential + CallbackEndpointPath = "/callback" JWKSEndpointPath = "/jwks.json" ) +const ( + // Just in case we need to make a breaking change to the format of the upstream state param, + // we are including a format version number. This gives the opportunity for a future version of Pinniped + // to have the consumer of this format decide to reject versions that it doesn't understand. + UpstreamStateParamFormatVersion = "1" + + // The `name` passed to the encoder for encoding the upstream state param value. This name is short + // because it will be encoded into the upstream state param value and we're trying to keep that small. + UpstreamStateParamEncodingName = "s" + + // CSRFCookieName is the name of the browser cookie which shall hold our CSRF value. + // The `__Host` prefix has a special meaning. See: + // https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes. + CSRFCookieName = "__Host-pinniped-csrf" + + // CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF + // cookie contents. + CSRFCookieEncodingName = "csrf" +) + +// Encoder is the encoding side of the securecookie.Codec interface. +type Encoder interface { + Encode(name string, value interface{}) (string, error) +} + +// Decoder is the decoding side of the securecookie.Codec interface. +type Decoder interface { + Decode(name, value string, into interface{}) error +} + +// Codec is both the encoding and decoding sides of the securecookie.Codec interface. It is +// interface'd here so that we properly wrap the securecookie dependency. +type Codec interface { + Encoder + Decoder +} + +// UpstreamStateParamData is the format of the state parameter that we use when we communicate to an +// upstream OIDC provider. +// +// Keep the JSON to a minimal size because the upstream provider could impose size limitations on +// the state param. +type UpstreamStateParamData struct { + AuthParams string `json:"p"` + UpstreamName string `json:"u"` + Nonce nonce.Nonce `json:"n"` + CSRFToken csrftoken.CSRFToken `json:"c"` + PKCECode pkce.Code `json:"k"` + FormatVersion string `json:"v"` +} + func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { return &fosite.DefaultOpenIDConnectClient{ DefaultClient: &fosite.DefaultClient{ @@ -34,8 +91,8 @@ func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { } func FositeOauth2Helper( - issuerURL string, - oauthStore fosite.Storage, + oauthStore interface{}, + issuer string, hmacSecretOfLengthAtLeast32 []byte, jwtSigningKey *ecdsa.PrivateKey, ) fosite.OAuth2Provider { @@ -47,7 +104,7 @@ func FositeOauth2Helper( RefreshTokenLifespan: 16 * time.Hour, // long enough for a single workday - IDTokenIssuer: issuerURL, + IDTokenIssuer: issuer, TokenURL: "", // TODO set once we have this endpoint written ScopeStrategy: fosite.ExactScopeStrategy, // be careful and only support exact string matching for scopes @@ -75,3 +132,7 @@ func FositeOauth2Helper( compose.OAuth2PKCEFactory, ) } + +type IDPListGetter interface { + GetIDPList() []provider.UpstreamOIDCIdentityProviderI +} diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go new file mode 100644 index 00000000..5b214e5c --- /dev/null +++ b/internal/oidc/oidctestutil/oidc.go @@ -0,0 +1,129 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package oidctestutil + +import ( + "context" + "net/url" + + "golang.org/x/oauth2" + + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" +) + +// Test helpers for the OIDC package. + +// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). +type ExchangeAuthcodeAndValidateTokenArgs struct { + Ctx context.Context + Authcode string + PKCECodeVerifier pkce.Code + ExpectedIDTokenNonce nonce.Nonce + RedirectURI string +} + +type TestUpstreamOIDCIdentityProvider struct { + Name string + ClientID string + AuthorizationURL url.URL + UsernameClaim string + GroupsClaim string + Scopes []string + ExchangeAuthcodeAndValidateTokensFunc func( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, + ) (oidctypes.Token, map[string]interface{}, error) + + exchangeAuthcodeAndValidateTokensCallCount int + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs +} + +func (u *TestUpstreamOIDCIdentityProvider) GetName() string { + return u.Name +} + +func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string { + return u.ClientID +} + +func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL { + return &u.AuthorizationURL +} + +func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string { + return u.Scopes +} + +func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string { + return u.UsernameClaim +} + +func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string { + return u.GroupsClaim +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, + redirectURI string, +) (oidctypes.Token, map[string]interface{}, error) { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + u.exchangeAuthcodeAndValidateTokensCallCount++ + u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ + Ctx: ctx, + Authcode: authcode, + PKCECodeVerifier: pkceCodeVerifier, + ExpectedIDTokenNonce: expectedIDTokenNonce, + RedirectURI: redirectURI, + }) + return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int { + return u.exchangeAuthcodeAndValidateTokensCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + return u.exchangeAuthcodeAndValidateTokensArgs[call] +} + +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + panic("implement me") +} + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { + idpProvider := provider.NewDynamicUpstreamIDPProvider() + upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) + for i := range upstreamOIDCIdentityProviders { + upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i]) + } + idpProvider.SetIDPList(upstreams) + return idpProvider +} + +// Declare a separate type from the production code to ensure that the state param's contents was serialized +// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of +// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality +// assertions about the redirect URL in this test. +type ExpectedUpstreamStateParamFormat struct { + P string `json:"p"` + U string `json:"u"` + N string `json:"n"` + C string `json:"c"` + K string `json:"k"` + V string `json:"v"` +} diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index bb26cef2..be25ffe8 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -4,48 +4,73 @@ package provider import ( + "context" "net/url" "sync" + + "golang.org/x/oauth2" + + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" ) -type UpstreamOIDCIdentityProvider struct { +type UpstreamOIDCIdentityProviderI interface { // A name for this upstream provider, which will be used as a component of the path for the callback endpoint // hosted by the Supervisor. - Name string + GetName() string - // The Oauth client ID registered with the upstream provider to be used in the authorization flow. - ClientID string + // The Oauth client ID registered with the upstream provider to be used in the authorization code flow. + GetClientID() string // The Authorization Endpoint fetched from discovery. - AuthorizationURL url.URL + GetAuthorizationURL() *url.URL // Scopes to request in authorization flow. - Scopes []string + GetScopes() []string + + // ID Token username claim name. May return empty string, in which case we will use some reasonable defaults. + GetUsernameClaim() string + + // ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider. + GetGroupsClaim() string + + // Performs upstream OIDC authorization code exchange and token validation. + // Returns the validated raw tokens as well as the parsed claims of the ID token. + ExchangeAuthcodeAndValidateTokens( + ctx context.Context, + authcode string, + pkceCodeVerifier pkce.Code, + expectedIDTokenNonce nonce.Nonce, + redirectURI string, + ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) + + ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) } type DynamicUpstreamIDPProvider interface { - SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) - GetIDPList() []UpstreamOIDCIdentityProvider + SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) + GetIDPList() []UpstreamOIDCIdentityProviderI } type dynamicUpstreamIDPProvider struct { - oidcProviders []UpstreamOIDCIdentityProvider + oidcProviders []UpstreamOIDCIdentityProviderI mutex sync.RWMutex } func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider { return &dynamicUpstreamIDPProvider{ - oidcProviders: []UpstreamOIDCIdentityProvider{}, + oidcProviders: []UpstreamOIDCIdentityProviderI{}, } } -func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) { +func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) { p.mutex.Lock() // acquire a write lock defer p.mutex.Unlock() p.oidcProviders = oidcIDPs } -func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProvider { +func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProviderI { p.mutex.RLock() // acquire a read lock defer p.mutex.RUnlock() return p.oidcProviders diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 9414b5bc..1d80a0fc 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -9,9 +9,11 @@ import ( "sync" "github.com/gorilla/securecookie" + corev1client "k8s.io/client-go/kubernetes/typed/core/v1" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" + "go.pinniped.dev/internal/oidc/callback" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" @@ -30,19 +32,26 @@ type Manager struct { providerHandlers map[string]http.Handler // map of all routes for all providers nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data - idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs + idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs + secretsClient corev1client.SecretInterface } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. -func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager { +func NewManager( + nextHandler http.Handler, + dynamicJWKSProvider jwks.DynamicJWKSProvider, + idpListGetter oidc.IDPListGetter, + secretsClient corev1client.SecretInterface, +) *Manager { return &Manager{ providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler, dynamicJWKSProvider: dynamicJWKSProvider, idpListGetter: idpListGetter, + secretsClient: secretsClient, } } @@ -62,20 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { m.providerHandlers = make(map[string]http.Handler) for _, incomingProvider := range oidcProviders { - wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath - m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer()) + issuer := incomingProvider.Issuer() + issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() - jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath - m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider) + fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret // Use NullStorage for the authorize endpoint because we do not actually want to store anything until // the upstream callback endpoint is called later. - oauthHelper := oidc.FositeOauth2Helper( - incomingProvider.Issuer(), - oidc.NullStorage{}, - []byte("some secret - must have at least 32 bytes"), // TODO replace this secret - nil, // TODO: inject me properly - ) + oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider, nil) + + // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. + oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider, nil) // TODO use different codecs for the state and the cookie, because: // 1. we would like to state to have an embedded expiration date while the cookie does not need that @@ -86,10 +92,30 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { var encoder = securecookie.New(encoderHashKey, encoderBlockKey) encoder.SetSerializer(securecookie.JSONEncoder{}) - authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) + m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) - plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) + m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider) + + m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler( + issuer, + m.idpListGetter, + oauthHelperWithNullStorage, + csrftoken.Generate, + pkce.Generate, + nonce.Generate, + encoder, + encoder, + ) + + m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler( + m.idpListGetter, + oauthHelperWithKubeStorage, + encoder, + encoder, + issuer+oidc.CallbackEndpointPath, + ) + + plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index bd1a5d4f..a3f8090d 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -4,6 +4,7 @@ package manager import ( + "context" "encoding/json" "io/ioutil" "net/http" @@ -15,12 +16,17 @@ import ( "github.com/sclevine/spec" "github.com/stretchr/testify/require" "gopkg.in/square/go-jose.v2" + "k8s.io/client-go/kubernetes/fake" "go.pinniped.dev/internal/here" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" + "go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" ) func TestManager(t *testing.T) { @@ -31,6 +37,7 @@ func TestManager(t *testing.T) { nextHandler http.HandlerFunc fallbackHandlerWasCalled bool dynamicJWKSProvider jwks.DynamicJWKSProvider + kubeClient *fake.Clientset ) const ( @@ -41,6 +48,7 @@ func TestManager(t *testing.T) { issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2KeyID = "issuer2-key" upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" + downstreamRedirectURL = "http://127.0.0.1:12345/callback" ) newGetRequest := func(url string) *http.Request { @@ -64,7 +72,7 @@ func TestManager(t *testing.T) { r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer) } - requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) { + requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) { recorder := httptest.NewRecorder() subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) @@ -79,6 +87,58 @@ func TestManager(t *testing.T) { "actual location %s did not start with expected prefix %s", actualLocation, expectedRedirectLocationPrefix, ) + + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + redirectStateParam := parsedLocation.Query().Get("state") + r.NotEmpty(redirectStateParam) + + cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2) + r.Len(cookieValueAndDirectivesSplit, 2) + cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=") + r.Len(cookieKeyValueSplit, 2) + csrfCookieName := cookieKeyValueSplit[0] + r.Equal("__Host-pinniped-csrf", csrfCookieName) + csrfCookieValue := cookieKeyValueSplit[1] + r.NotEmpty(csrfCookieValue) + + // Return the important parts of the response so we can use them in our next request to the callback endpoint + return csrfCookieValue, redirectStateParam + } + + requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) { + recorder := httptest.NewRecorder() + + numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions()) + + getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix) + getRequest.AddCookie(&http.Cookie{ + Name: "__Host-pinniped-csrf", + Value: csrfCookieValue, + }) + subject.ServeHTTP(recorder, getRequest) + + r.False(fallbackHandlerWasCalled) + + // Check just enough of the response to ensure that we wired up the callback endpoint correctly. + // The endpoint's own unit tests cover everything else. + r.Equal(http.StatusFound, recorder.Code) + actualLocation := recorder.Header().Get("Location") + r.True( + strings.HasPrefix(actualLocation, downstreamRedirectURL), + "actual location %s did not start with expected prefix %s", + actualLocation, downstreamRedirectURL, + ) + parsedLocation, err := url.Parse(actualLocation) + r.NoError(err) + actualLocationQueryParams := parsedLocation.Query() + r.Contains(actualLocationQueryParams, "code") + r.Equal("openid", actualLocationQueryParams.Get("scope")) + r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state")) + + // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. + r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3, + "did not perform any kube actions during the callback request, but should have") } requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { @@ -107,17 +167,27 @@ func TestManager(t *testing.T) { parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) r.NoError(err) - idpListGetter := provider.NewDynamicUpstreamIDPProvider() - idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{ - { - Name: "test-idp", - ClientID: "test-client-id", - AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, - Scopes: []string{"test-scope"}, + idpListGetter := oidctestutil.NewIDPListGetter(&oidctestutil.TestUpstreamOIDCIdentityProvider{ + Name: "test-idp", + ClientID: "test-client-id", + AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, + Scopes: []string{"test-scope"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + return oidctypes.Token{}, + map[string]interface{}{ + "iss": "https://some-issuer.com", + "sub": "some-subject", + "username": "test-username", + "groups": "test-group1", + }, + nil }, }) - subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) + kubeClient = fake.NewSimpleClientset() + secretsClient := kubeClient.CoreV1().Secrets("some-namespace") + + subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient) }) when("given no providers via SetProviders()", func() { @@ -164,7 +234,6 @@ func TestManager(t *testing.T) { requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) - authRedirectURI := "http://127.0.0.1/callback" authRequestParams := "?" + url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, @@ -173,7 +242,7 @@ func TestManager(t *testing.T) { "nonce": []string{"some-nonce-value"}, "code_challenge": []string{"some-challenge"}, "code_challenge_method": []string{"S256"}, - "redirect_uri": []string{authRedirectURI}, + "redirect_uri": []string{downstreamRedirectURL}, }.Encode() requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) @@ -181,7 +250,20 @@ func TestManager(t *testing.T) { // Hostnames are case-insensitive, so test that we can handle that. requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) - requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + csrfCookieValue, upstreamStateParam := + requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + + callbackRequestParams := "?" + url.Values{ + "code": []string{"some-fake-code"}, + "state": []string{upstreamStateParam}, + }.Encode() + + requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue) + + // // Hostnames are case-insensitive, so test that we can handle that. + requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue) + requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue) } when("given some valid providers via SetProviders()", func() { diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go new file mode 100644 index 00000000..77247602 --- /dev/null +++ b/internal/testutil/assertions.go @@ -0,0 +1,24 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { + require.InDeltaf(t, + float64(t1.UnixNano()), + float64(t2.UnixNano()), + float64(delta.Nanoseconds()), + "expected %s and %s to be < %s apart, but they are %s apart", + t1.Format(time.RFC3339Nano), + t2.Format(time.RFC3339Nano), + delta.String(), + t1.Sub(t2).String(), + ) +} diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go new file mode 100644 index 00000000..a789cb85 --- /dev/null +++ b/internal/upstreamoidc/upstreamoidc.go @@ -0,0 +1,117 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions. +package upstreamoidc + +import ( + "context" + "net/http" + "net/url" + + "github.com/coreos/go-oidc" + "golang.org/x/oauth2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" + "go.pinniped.dev/pkg/oidcclient/pkce" +) + +func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { + return &ProviderConfig{Config: config, Provider: provider, Client: client} +} + +// ProviderConfig holds the active configuration of an upstream OIDC provider. +type ProviderConfig struct { + Name string + UsernameClaim string + GroupsClaim string + Config *oauth2.Config + Provider interface { + Verifier(*oidc.Config) *oidc.IDTokenVerifier + } + Client *http.Client +} + +func (p *ProviderConfig) GetName() string { + return p.Name +} + +func (p *ProviderConfig) GetClientID() string { + return p.Config.ClientID +} + +func (p *ProviderConfig) GetAuthorizationURL() *url.URL { + result, _ := url.Parse(p.Config.Endpoint.AuthURL) + return result +} + +func (p *ProviderConfig) GetScopes() []string { + return p.Config.Scopes +} + +func (p *ProviderConfig) GetUsernameClaim() string { + return p.UsernameClaim +} + +func (p *ProviderConfig) GetGroupsClaim() string { + return p.GroupsClaim +} + +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) { + tok, err := p.Config.Exchange( + oidc.ClientContext(ctx, p.Client), + authcode, + pkceCodeVerifier.Verifier(), + oauth2.SetAuthURLParam("redirect_uri", redirectURI), + ) + if err != nil { + return oidctypes.Token{}, nil, err + } + + return p.ValidateToken(ctx, tok, expectedIDTokenNonce) +} + +func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { + idTok, hasIDTok := tok.Extra("id_token").(string) + if !hasIDTok { + return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + } + validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok) + if err != nil { + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + if validated.AccessTokenHash != "" { + if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + } + if expectedIDTokenNonce != "" { + if err := expectedIDTokenNonce.Validate(validated); err != nil { + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + } + } + + var validatedClaims map[string]interface{} + if err := validated.Claims(&validatedClaims); err != nil { + return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) + } + + return oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ + Token: tok.AccessToken, + Type: tok.TokenType, + Expiry: metav1.NewTime(tok.Expiry), + }, + RefreshToken: &oidctypes.RefreshToken{ + Token: tok.RefreshToken, + }, + IDToken: &oidctypes.IDToken{ + Token: idTok, + Expiry: metav1.NewTime(validated.Expiry), + }, + }, validatedClaims, nil +} diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go new file mode 100644 index 00000000..541d502f --- /dev/null +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -0,0 +1,223 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package upstreamoidc + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/coreos/go-oidc" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "gopkg.in/square/go-jose.v2" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" +) + +func TestProviderConfig(t *testing.T) { + t.Run("getters get", func(t *testing.T) { + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + Endpoint: oauth2.Endpoint{AuthURL: "https://example.com"}, + Scopes: []string{"scope1", "scope2"}, + }, + } + require.Equal(t, "test-name", p.GetName()) + require.Equal(t, "test-client-id", p.GetClientID()) + require.Equal(t, "https://example.com", p.GetAuthorizationURL().String()) + require.ElementsMatch(t, []string{"scope1", "scope2"}, p.GetScopes()) + require.Equal(t, "test-username-claim", p.GetUsernameClaim()) + require.Equal(t, "test-groups-claim", p.GetGroupsClaim()) + }) + + const ( + // Test JWTs generated with https://smallstep.com/docs/cli/crypto/jwt/: + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + invalidAccessTokenHashIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg" //nolint: gosec + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + invalidNonceIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ" //nolint: gosec + + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"foo": "bar", "bat": "baz"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" + validIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImJhdCI6ImJheiIsImZvbyI6ImJhciIsImlhdCI6MTYwNjc2ODU5MywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDY3Njg1OTMsInN1YiI6InRlc3QtdXNlciJ9.DuqVZ7pGhHqKz7gNr4j2W1s1N8YrSltktH4wW19L4oD1OE2-O72jAnNj5xdjilsa8l7h9ox-5sMF0Tkh3BdRlHQK9dEtNm9tW-JreUnWJ3LCqUs-LZp4NG7edvq2sH_1Bn7O2_NQV51s8Pl04F60CndjQ4NM-6WkqDQTKyY6vJXU7idvM-6TM2HJZK-Na88cOJ9KIK37tL5DhcbsHVF47Dq8uPZ0KbjNQjJLAIi_1GeQBgc6yJhDUwRY4Xu6S0dtTHA6xTI8oSXoamt4bkViEHfJBp97LZQiNz8mku5pVc0aNwP1p4hMHxRHhLXrJjbh-Hx4YFjxtOnIq9t1mHlD4A" //nolint: gosec + ) + + tests := []struct { + name string + authCode string + expectNonce nonce.Nonce + returnIDTok string + wantErr string + wantToken oidctypes.Token + wantClaims map[string]interface{} + }{ + { + name: "exchange fails with network error", + authCode: "invalid-auth-code", + wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", + }, + { + name: "missing ID token", + authCode: "valid", + wantErr: "received response missing ID token", + }, + { + name: "invalid ID token", + authCode: "valid", + returnIDTok: "invalid-jwt", + wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + }, + { + name: "invalid access token hash", + authCode: "valid", + returnIDTok: invalidAccessTokenHashIDToken, + wantErr: "received invalid ID token: access token hash does not match value in ID token", + }, + { + name: "invalid nonce", + authCode: "valid", + expectNonce: "test-nonce", + returnIDTok: invalidNonceIDToken, + wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`, + }, + { + name: "invalid nonce but not checked", + authCode: "valid", + expectNonce: "", + returnIDTok: invalidNonceIDToken, + wantToken: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ + Token: "test-access-token", + Expiry: metav1.Time{}, + }, + RefreshToken: &oidctypes.RefreshToken{ + Token: "test-refresh-token", + }, + IDToken: &oidctypes.IDToken{ + Token: invalidNonceIDToken, + Expiry: metav1.Time{}, + }, + }, + }, + { + name: "valid", + authCode: "valid", + returnIDTok: validIDToken, + wantToken: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ + Token: "test-access-token", + Expiry: metav1.Time{}, + }, + RefreshToken: &oidctypes.RefreshToken{ + Token: "test-refresh-token", + }, + IDToken: &oidctypes.IDToken{ + Token: validIDToken, + Expiry: metav1.Time{}, + }, + }, + wantClaims: map[string]interface{}{ + "foo": "bar", + "bat": "baz", + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.NoError(t, r.ParseForm()) + require.Equal(t, "test-client-id", r.Form.Get("client_id")) + require.Equal(t, "test-pkce", r.Form.Get("code_verifier")) + require.Equal(t, "authorization_code", r.Form.Get("grant_type")) + require.NotEmpty(t, r.Form.Get("code")) + if r.Form.Get("code") != "valid" { + http.Error(w, "invalid authorization code", http.StatusForbidden) + return + } + var response struct { + oauth2.Token + IDToken string `json:"id_token,omitempty"` + } + response.AccessToken = "test-access-token" + response.RefreshToken = "test-refresh-token" + response.Expiry = time.Now().Add(time.Hour) + response.IDToken = tt.returnIDTok + w.Header().Set("content-type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(&response)) + })) + t.Cleanup(tokenServer.Close) + + p := ProviderConfig{ + Name: "test-name", + UsernameClaim: "test-username-claim", + GroupsClaim: "test-groups-claim", + Config: &oauth2.Config{ + ClientID: "test-client-id", + Endpoint: oauth2.Endpoint{ + AuthURL: "https://example.com", + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + Scopes: []string{"scope1", "scope2"}, + }, + Provider: &mockProvider{}, + } + + ctx := context.Background() + + tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Equal(t, oidctypes.Token{}, tok) + require.Nil(t, claims) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantToken, tok) + + for k, v := range tt.wantClaims { + require.Equal(t, v, claims[k]) + } + }) + } +} + +// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. +func mockVerifier() *oidc.IDTokenVerifier { + mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) + mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). + AnyTimes(). + DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) { + jws, err := jose.ParseSigned(jwt) + if err != nil { + return nil, err + } + return jws.UnsafePayloadWithoutVerification(), nil + }) + + return oidc.NewVerifier("", mockKeySet, &oidc.Config{ + SkipIssuerCheck: true, + SkipExpiryCheck: true, + SkipClientIDCheck: true, + }) +} + +type mockProvider struct{} + +func (m *mockProvider) Verifier(_ *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } diff --git a/pkg/oidcclient/filesession/cachefile.go b/pkg/oidcclient/filesession/cachefile.go index 3629ca5f..9ea46bc0 100644 --- a/pkg/oidcclient/filesession/cachefile.go +++ b/pkg/oidcclient/filesession/cachefile.go @@ -17,6 +17,7 @@ import ( "sigs.k8s.io/yaml" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) var ( @@ -48,7 +49,7 @@ type ( Key oidcclient.SessionCacheKey `json:"key"` CreationTimestamp metav1.Time `json:"creationTimestamp"` LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"` - Tokens oidcclient.Token `json:"tokens"` + Tokens oidctypes.Token `json:"tokens"` } ) diff --git a/pkg/oidcclient/filesession/cachefile_test.go b/pkg/oidcclient/filesession/cachefile_test.go index 4a30ae74..b1e1c984 100644 --- a/pkg/oidcclient/filesession/cachefile_test.go +++ b/pkg/oidcclient/filesession/cachefile_test.go @@ -13,6 +13,7 @@ import ( "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) // validSession should be the same data as `testdata/valid.yaml`. @@ -28,17 +29,17 @@ var validSession = sessionCache{ }, CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()), LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -140,8 +141,8 @@ func TestNormalized(t *testing.T) { // ID token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + Tokens: oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "", Expiry: metav1.NewTime(now.Add(1 * time.Minute)), }, @@ -150,8 +151,8 @@ func TestNormalized(t *testing.T) { // ID token is expired. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - IDToken: &oidcclient.IDToken{ + Tokens: oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), }, @@ -160,8 +161,8 @@ func TestNormalized(t *testing.T) { // Access token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "", Expiry: metav1.NewTime(now.Add(1 * time.Minute)), }, @@ -170,8 +171,8 @@ func TestNormalized(t *testing.T) { // Access token is expired. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), }, @@ -180,8 +181,8 @@ func TestNormalized(t *testing.T) { // Refresh token is empty, but not nil. { LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "", }, }, @@ -189,8 +190,8 @@ func TestNormalized(t *testing.T) { // Session has a refresh token but it hasn't been used in >90 days. { LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -199,8 +200,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token2", }, }, @@ -208,8 +209,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token1", }, }, @@ -223,8 +224,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token1", }, }, @@ -232,8 +233,8 @@ func TestNormalized(t *testing.T) { { CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now), - Tokens: oidcclient.Token{ - RefreshToken: &oidcclient.RefreshToken{ + Tokens: oidctypes.Token{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token2", }, }, diff --git a/pkg/oidcclient/filesession/filesession.go b/pkg/oidcclient/filesession/filesession.go index 47e0f761..151fde71 100644 --- a/pkg/oidcclient/filesession/filesession.go +++ b/pkg/oidcclient/filesession/filesession.go @@ -16,6 +16,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) const ( @@ -65,14 +66,14 @@ type Cache struct { } // GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached. -func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token { +func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidctypes.Token { // If the cache file does not exist, exit immediately with no error log if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) { return nil } // Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it. - var result *oidcclient.Token + var result *oidctypes.Token c.withCache(func(cache *sessionCache) { if entry := cache.lookup(key); entry != nil { result = &entry.Tokens @@ -84,7 +85,7 @@ func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token { // PutToken stores the provided token into the session cache under the given parameters. It does not return an error // but may silently fail to update the session cache. -func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidcclient.Token) { +func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidctypes.Token) { // Create the cache directory if it does not exist. if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) { c.errReporter(fmt.Errorf("could not create session cache directory: %w", err)) diff --git a/pkg/oidcclient/filesession/filesession_test.go b/pkg/oidcclient/filesession/filesession_test.go index 2d78a128..2ba7c55b 100644 --- a/pkg/oidcclient/filesession/filesession_test.go +++ b/pkg/oidcclient/filesession/filesession_test.go @@ -17,6 +17,7 @@ import ( "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient" + "go.pinniped.dev/pkg/oidcclient/oidctypes" ) func TestNew(t *testing.T) { @@ -38,7 +39,7 @@ func TestGetToken(t *testing.T) { trylockFunc func(*testing.T) error unlockFunc func(*testing.T) error key oidcclient.SessionCacheKey - want *oidcclient.Token + want *oidctypes.Token wantErrors []string wantTestFile func(t *testing.T, tmp string) }{ @@ -99,17 +100,17 @@ func TestGetToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -137,17 +138,17 @@ func TestGetToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -161,17 +162,17 @@ func TestGetToken(t *testing.T) { RedirectURI: "http://localhost:0/callback", }, wantErrors: []string{}, - want: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + want: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "test-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", }, }, @@ -219,7 +220,7 @@ func TestPutToken(t *testing.T) { name string makeTestFile func(t *testing.T, tmp string) key oidcclient.SessionCacheKey - token *oidcclient.Token + token *oidctypes.Token wantErrors []string wantTestFile func(t *testing.T, tmp string) }{ @@ -245,17 +246,17 @@ func TestPutToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "old-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "old-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "old-refresh-token", }, }, @@ -269,17 +270,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, @@ -288,17 +289,17 @@ func TestPutToken(t *testing.T) { require.NoError(t, err) require.Len(t, cache.Sessions, 1) require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) - require.Equal(t, oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + require.Equal(t, oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, cache.Sessions[0].Tokens) @@ -317,17 +318,17 @@ func TestPutToken(t *testing.T) { }, CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), - Tokens: oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + Tokens: oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "old-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "old-id-token", Expiry: metav1.NewTime(now.Add(1 * time.Hour)), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "old-refresh-token", }, }, @@ -341,17 +342,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, @@ -360,17 +361,17 @@ func TestPutToken(t *testing.T) { require.NoError(t, err) require.Len(t, cache.Sessions, 2) require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) - require.Equal(t, oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + require.Equal(t, oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, cache.Sessions[1].Tokens) @@ -389,17 +390,17 @@ func TestPutToken(t *testing.T) { Scopes: []string{"email", "offline_access", "openid", "profile"}, RedirectURI: "http://localhost:0/callback", }, - token: &oidcclient.Token{ - AccessToken: &oidcclient.AccessToken{ + token: &oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{ Token: "new-access-token", Type: "Bearer", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - IDToken: &oidcclient.IDToken{ + IDToken: &oidctypes.IDToken{ Token: "new-id-token", Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), }, - RefreshToken: &oidcclient.RefreshToken{ + RefreshToken: &oidctypes.RefreshToken{ Token: "new-refresh-token", }, }, diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 0898f944..0df34622 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -16,11 +16,13 @@ import ( "github.com/coreos/go-oidc" "github.com/pkg/browser" "golang.org/x/oauth2" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" ) @@ -51,24 +53,24 @@ type handlerState struct { callbackPath string // Generated parameters of a login flow. - idTokenVerifier *oidc.IDTokenVerifier - oauth2Config *oauth2.Config - state state.State - nonce nonce.Nonce - pkce pkce.Code + provider *oidc.Provider + oauth2Config *oauth2.Config + state state.State + nonce nonce.Nonce + pkce pkce.Code // External calls for things. generateState func() (state.State, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error - oidcDiscover func(context.Context, string) (discoveryI, error) + getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI callbacks chan callbackResult } type callbackResult struct { - token *Token + token *oidctypes.Token err error } @@ -87,10 +89,11 @@ func WithContext(ctx context.Context) Option { // WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the // authorization code callback. By default, a random high port will be chosen which requires the authorization server // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252: -// The authorization server MUST allow any port to be specified at the -// time of the request for loopback IP redirect URIs, to accommodate -// clients that obtain an available ephemeral port from the operating -// system at the time of the request. +// +// The authorization server MUST allow any port to be specified at the +// time of the request for loopback IP redirect URIs, to accommodate +// clients that obtain an available ephemeral port from the operating +// system at the time of the request. func WithListenPort(port uint16) Option { return func(h *handlerState) error { h.listenAddr = fmt.Sprintf("localhost:%d", port) @@ -116,6 +119,19 @@ func WithBrowserOpen(openURL func(url string) error) Option { } } +// SessionCacheKey contains the data used to select a valid session cache entry. +type SessionCacheKey struct { + Issuer string `json:"issuer"` + ClientID string `json:"clientID"` + Scopes []string `json:"scopes"` + RedirectURI string `json:"redirect_uri"` +} + +type SessionCache interface { + GetToken(SessionCacheKey) *oidctypes.Token + PutToken(SessionCacheKey, *oidctypes.Token) +} + // WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens. func WithSessionCache(cache SessionCache) Option { return func(h *handlerState) error { @@ -135,16 +151,11 @@ func WithClient(httpClient *http.Client) Option { // nopCache is a SessionCache that doesn't actually do anything. type nopCache struct{} -func (*nopCache) GetToken(SessionCacheKey) *Token { return nil } -func (*nopCache) PutToken(SessionCacheKey, *Token) {} - -type discoveryI interface { - Endpoint() oauth2.Endpoint - Verifier(*oidc.Config) *oidc.IDTokenVerifier -} +func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil } +func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {} // Login performs an OAuth2/OIDC authorization code login using a localhost listener. -func Login(issuer string, clientID string, opts ...Option) (*Token, error) { +func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) { h := handlerState{ issuer: issuer, clientID: clientID, @@ -161,9 +172,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, - oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { - return oidc.NewProvider(ctx, iss) - }, + getProvider: upstreamoidc.New, } for _, opt := range opts { if err := opt(&h); err != nil { @@ -208,16 +217,15 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { } // Perform OIDC discovery. - discovered, err := h.oidcDiscover(h.ctx, h.issuer) + h.provider, err = oidc.NewProvider(h.ctx, h.issuer) if err != nil { return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) } - h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID}) // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. h.oauth2Config = &oauth2.Config{ ClientID: h.clientID, - Endpoint: discovered.Endpoint(), + Endpoint: h.provider.Endpoint(), Scopes: h.scopes, } @@ -274,7 +282,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { } } -func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) { +func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { ctx, cancel := context.WithTimeout(ctx, refreshTimeout) defer cancel() refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) @@ -287,7 +295,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshT // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - return h.validateToken(ctx, refreshed, false) + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") + if err != nil { + return nil, err + } + return &token, nil } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { @@ -314,58 +326,25 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) } - // Exchange the authorization code for access, ID, and refresh tokens. - oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) + // Exchange the authorization code for access, ID, and refresh tokens and perform required + // validations on the returned ID token. + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). + ExchangeAuthcodeAndValidateTokens( + r.Context(), + params.Get("code"), + h.pkce, + h.nonce, + h.oauth2Config.RedirectURL, + ) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } - // Perform required validations on the returned ID token. - token, err := h.validateToken(r.Context(), oauth2Tok, true) - if err != nil { - return err - } - - h.callbacks <- callbackResult{token: token} + h.callbacks <- callbackResult{token: &token} _, _ = w.Write([]byte("you have been logged in and may now close this tab")) return nil } -func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) { - idTok, hasIDTok := tok.Extra("id_token").(string) - if !hasIDTok { - return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") - } - validated, err := h.idTokenVerifier.Verify(ctx, idTok) - if err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - if validated.AccessTokenHash != "" { - if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - } - if checkNonce { - if err := h.nonce.Validate(validated); err != nil { - return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) - } - } - return &Token{ - AccessToken: &AccessToken{ - Token: tok.AccessToken, - Type: tok.TokenType, - Expiry: metav1.NewTime(tok.Expiry), - }, - RefreshToken: &RefreshToken{ - Token: tok.RefreshToken, - }, - IDToken: &IDToken{ - Token: idTok, - Expiry: metav1.NewTime(validated.Expiry), - }, - }, nil -} - func (h *handlerState) serve(listener net.Listener) func() { mux := http.NewServeMux() mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index b323f586..96d790ba 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -18,12 +18,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" - "gopkg.in/square/go-jose.v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/internal/httputil/httperr" - "go.pinniped.dev/internal/mocks/mockkeyset" + "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider" + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" + "go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" ) @@ -31,19 +33,19 @@ import ( // mockSessionCache exists to avoid an import cycle if we generate mocks into another package. type mockSessionCache struct { t *testing.T - getReturnsToken *Token + getReturnsToken *oidctypes.Token sawGetKeys []SessionCacheKey sawPutKeys []SessionCacheKey - sawPutTokens []*Token + sawPutTokens []*oidctypes.Token } -func (m *mockSessionCache) GetToken(key SessionCacheKey) *Token { +func (m *mockSessionCache) GetToken(key SessionCacheKey) *oidctypes.Token { m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID) m.sawGetKeys = append(m.sawGetKeys, key) return m.getReturnsToken } -func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) { +func (m *mockSessionCache) PutToken(key SessionCacheKey, token *oidctypes.Token) { m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token) m.sawPutKeys = append(m.sawPutKeys, key) m.sawPutTokens = append(m.sawPutTokens, token) @@ -54,20 +56,10 @@ func TestLogin(t *testing.T) { time1Unix := int64(2075807775) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) - testToken := Token{ - AccessToken: &AccessToken{ - Token: "test-access-token", - Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), - }, - RefreshToken: &RefreshToken{ - Token: "test-refresh-token", - }, - IDToken: &IDToken{ - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above): - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775 - Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA", - Expiry: metav1.NewTime(time1.Add(2 * time.Minute)), - }, + testToken := oidctypes.Token{ + AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, + IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))}, } // Start a test server that returns 500 errors @@ -76,7 +68,7 @@ func TestLogin(t *testing.T) { })) t.Cleanup(errorServer.Close) - // Start a test server that returns a real keyset and answers refresh requests. + // Start a test server that returns a real discovery document and answers refresh requests. providerMux := http.NewServeMux() successServer := httptest.NewServer(providerMux) t.Cleanup(successServer.Close) @@ -144,7 +136,7 @@ func TestLogin(t *testing.T) { issuer string clientID string wantErr string - wantToken *Token + wantToken *oidctypes.Token }{ { name: "option error", @@ -191,8 +183,8 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity }, @@ -246,12 +238,20 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(testToken, nil, nil) + return mock + } + + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, }} t.Cleanup(func() { cacheKey := SessionCacheKey{ @@ -266,12 +266,6 @@ func TestLogin(t *testing.T) { require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) }) h.cache = cache - - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } return nil } }, @@ -283,12 +277,20 @@ func TestLogin(t *testing.T) { clientID: "test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). + Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error")) + return mock + } + + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, }} t.Cleanup(func() { require.Empty(t, cache.sawPutKeys) @@ -296,16 +298,10 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - return nil } }, - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + wantErr: "some validation error", }, { name: "session cache hit but refresh fails", @@ -313,12 +309,12 @@ func TestLogin(t *testing.T) { clientID: "not-the-test-client-id", opt: func(t *testing.T) Option { return func(h *handlerState) error { - cache := &mockSessionCache{t: t, getReturnsToken: &Token{ - IDToken: &IDToken{ + cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ Token: "expired-test-id-token", Expiry: metav1.Now(), // less than Now() + minIDTokenValidity }, - RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, }} t.Cleanup(func() { require.Empty(t, cache.sawPutKeys) @@ -326,12 +322,6 @@ func TestLogin(t *testing.T) { }) h.cache = cache - h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { - provider, err := oidc.NewProvider(ctx, iss) - require.NoError(t, err) - return &mockDiscovery{provider: provider}, nil - } - h.listenAddr = "invalid-listen-address" return nil @@ -413,7 +403,7 @@ func TestLogin(t *testing.T) { t.Cleanup(func() { require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) - require.Equal(t, []*Token{&testToken}, cache.sawPutTokens) + require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens) }) require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h)) @@ -481,7 +471,7 @@ func TestLogin(t *testing.T) { require.NotNil(t, tok.AccessToken) require.Equal(t, want.Token, tok.AccessToken.Token) require.Equal(t, want.Type, tok.AccessToken.Type) - requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) + testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) } else { assert.Nil(t, tok.AccessToken) } @@ -489,7 +479,7 @@ func TestLogin(t *testing.T) { if want := tt.wantToken.IDToken; want != nil { require.NotNil(t, tok.IDToken) require.Equal(t, want.Token, tok.IDToken.Token) - requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) + testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) } else { assert.Nil(t, tok.IDToken) } @@ -498,11 +488,13 @@ func TestLogin(t *testing.T) { } func TestHandleAuthCodeCallback(t *testing.T) { + const testRedirectURI = "http://127.0.0.1:12324/callback" + tests := []struct { name string method string query string - returnIDTok string + opt func(t *testing.T) Option wantErr string wantHTTPStatus int }{ @@ -528,94 +520,51 @@ func TestHandleAuthCodeCallback(t *testing.T) { { name: "invalid code", query: "state=test-state&code=invalid", - wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", + wantErr: "could not complete code exchange: some exchange error", wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "missing ID token", - query: "state=test-state&code=valid", - returnIDTok: "", - wantErr: "received response missing ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid ID token", - query: "state=test-state&code=valid", - returnIDTok: "invalid-jwt", - wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid access token hash", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg", - - wantErr: "received invalid ID token: access token hash does not match value in ID token", - wantHTTPStatus: http.StatusBadRequest, - }, - { - name: "invalid nonce", - query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ", - - wantHTTPStatus: http.StatusBadRequest, - wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) + return mock + } + return nil + } + }, }, { name: "valid", query: "state=test-state&code=valid", - - // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: - // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" - returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) + return mock + } + return nil + } + }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.NoError(t, r.ParseForm()) - require.Equal(t, "test-client-id", r.Form.Get("client_id")) - require.Equal(t, "test-pkce", r.Form.Get("code_verifier")) - require.Equal(t, "authorization_code", r.Form.Get("grant_type")) - require.NotEmpty(t, r.Form.Get("code")) - if r.Form.Get("code") != "valid" { - http.Error(w, "invalid authorization code", http.StatusForbidden) - return - } - var response struct { - oauth2.Token - IDToken string `json:"id_token,omitempty"` - } - response.AccessToken = "test-access-token" - response.Expiry = time.Now().Add(time.Hour) - response.IDToken = tt.returnIDTok - w.Header().Set("content-type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(&response)) - })) - t.Cleanup(tokenServer.Close) - h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), pkce: pkce.Code("test-pkce"), nonce: nonce.Nonce("test-nonce"), - oauth2Config: &oauth2.Config{ - ClientID: "test-client-id", - RedirectURL: "http://localhost:12345/callback", - Endpoint: oauth2.Endpoint{ - TokenURL: tokenServer.URL, - AuthStyle: oauth2.AuthStyleInParams, - }, - }, - idTokenVerifier: mockVerifier(), + } + if tt.opt != nil { + require.NoError(t, tt.opt(t)(h)) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -651,47 +600,34 @@ func TestHandleAuthCodeCallback(t *testing.T) { } require.NoError(t, result.err) require.NotNil(t, result.token) - require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) + require.Equal(t, result.token.IDToken.Token, "test-id-token") } }) } } -// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. -func mockVerifier() *oidc.IDTokenVerifier { - mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) - mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). - AnyTimes(). - DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) { - jws, err := jose.ParseSigned(jwt) - if err != nil { - return nil, err - } - return jws.UnsafePayloadWithoutVerification(), nil - }) - - return oidc.NewVerifier("", mockKeySet, &oidc.Config{ - SkipIssuerCheck: true, - SkipExpiryCheck: true, - SkipClientIDCheck: true, - }) +func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI { + t.Helper() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl) } -type mockDiscovery struct{ provider *oidc.Provider } +// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token. +type hasAccessTokenMatcher struct{ expected string } -func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } - -func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } - -func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { - require.InDeltaf(t, - float64(t1.UnixNano()), - float64(t2.UnixNano()), - float64(delta.Nanoseconds()), - "expected %s and %s to be < %s apart, but they are %s apart", - t1.Format(time.RFC3339Nano), - t2.Format(time.RFC3339Nano), - delta.String(), - t1.Sub(t2).String(), - ) +func (m hasAccessTokenMatcher) Matches(arg interface{}) bool { + return arg.(*oauth2.Token).AccessToken == m.expected +} + +func (m hasAccessTokenMatcher) Got(got interface{}) string { + return got.(*oauth2.Token).AccessToken +} + +func (m hasAccessTokenMatcher) String() string { + return m.expected +} + +func HasAccessToken(expected string) gomock.Matcher { + return hasAccessTokenMatcher{expected: expected} } diff --git a/pkg/oidcclient/types.go b/pkg/oidcclient/oidctypes/oidctypes.go similarity index 69% rename from pkg/oidcclient/types.go rename to pkg/oidcclient/oidctypes/oidctypes.go index 7fbf3a3f..94f5dcc9 100644 --- a/pkg/oidcclient/types.go +++ b/pkg/oidcclient/oidctypes/oidctypes.go @@ -1,11 +1,10 @@ // Copyright 2020 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -package oidcclient +// Package oidctypes provides core data types for OIDC token structures. +package oidctypes -import ( - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" -) +import v1 "k8s.io/apimachinery/pkg/apis/meta/v1" // AccessToken is an OAuth2 access token. type AccessToken struct { @@ -16,7 +15,7 @@ type AccessToken struct { Type string `json:"type,omitempty"` // Expiry is the optional expiration time of the access token. - Expiry metav1.Time `json:"expiryTimestamp,omitempty"` + Expiry v1.Time `json:"expiryTimestamp,omitempty"` } // RefreshToken is an OAuth2 refresh token. @@ -31,7 +30,7 @@ type IDToken struct { Token string `json:"token"` // Expiry is the optional expiration time of the ID token. - Expiry metav1.Time `json:"expiryTimestamp,omitempty"` + Expiry v1.Time `json:"expiryTimestamp,omitempty"` } // Token contains the elements of an OIDC session. @@ -47,16 +46,3 @@ type Token struct { // IDToken is an OpenID Connect ID token. IDToken *IDToken `json:"id,omitempty"` } - -// SessionCacheKey contains the data used to select a valid session cache entry. -type SessionCacheKey struct { - Issuer string `json:"issuer"` - ClientID string `json:"clientID"` - Scopes []string `json:"scopes"` - RedirectURI string `json:"redirect_uri"` -} - -type SessionCache interface { - GetToken(SessionCacheKey) *Token - PutToken(SessionCacheKey, *Token) -} diff --git a/test/deploy/dex/dex.yaml b/test/deploy/dex/dex.yaml index bd078f24..6a5ecfec 100644 --- a/test/deploy/dex/dex.yaml +++ b/test/deploy/dex/dex.yaml @@ -28,8 +28,7 @@ staticClients: name: 'Pinniped Supervisor' secret: pinniped-supervisor-secret redirectURIs: - - #@ "http://127.0.0.1:" + str(data.values.ports.cli) + "/callback" - - #@ "http://[::1]:" + str(data.values.ports.cli) + "/callback" + - https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback enablePasswordDB: true staticPasswords: - username: "pinny" diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go index 4753a3a6..c4c051eb 100644 --- a/test/integration/cli_test.go +++ b/test/integration/cli_test.go @@ -20,7 +20,6 @@ import ( "testing" "time" - "github.com/sclevine/agouti" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "gopkg.in/square/go-jose.v2" @@ -30,6 +29,7 @@ import ( "go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient/filesession" "go.pinniped.dev/test/library" + "go.pinniped.dev/test/library/browsertest" ) func TestCLIGetKubeconfig(t *testing.T) { @@ -108,80 +108,14 @@ func runPinnipedCLIGetKubeconfig(t *testing.T, pinnipedExe, token, namespaceName return string(output) } -type loginProviderPatterns struct { - Name string - IssuerPattern *regexp.Regexp - LoginPagePattern *regexp.Regexp - UsernameSelector string - PasswordSelector string - LoginButtonSelector string -} - -func getLoginProvider(t *testing.T) *loginProviderPatterns { - t.Helper() - issuer := library.IntegrationEnv(t).CLITestUpstream.Issuer - for _, p := range []loginProviderPatterns{ - { - Name: "Okta", - IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), - LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), - UsernameSelector: "input#okta-signin-username", - PasswordSelector: "input#okta-signin-password", - LoginButtonSelector: "input#okta-signin-submit", - }, - { - Name: "Dex", - IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`), - LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`), - UsernameSelector: "input#login", - PasswordSelector: "input#password", - LoginButtonSelector: "button#submit-login", - }, - } { - if p.IssuerPattern.MatchString(issuer) { - return &p - } - } - require.Failf(t, "could not find login provider for issuer %q", issuer) - return nil -} - func TestCLILoginOIDC(t *testing.T) { env := library.IntegrationEnv(t) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - // Find the login CSS selectors for the test issuer, or fail fast. - loginProvider := getLoginProvider(t) - // Start the browser driver. - t.Logf("opening browser driver") - caps := agouti.NewCapabilities() - if env.Proxy != "" { - t.Logf("configuring Chrome to use proxy %q", env.Proxy) - caps = caps.Proxy(agouti.ProxyConfig{ - ProxyType: "manual", - HTTPProxy: env.Proxy, - SSLProxy: env.Proxy, - NoProxy: "127.0.0.1", - }) - } - agoutiDriver := agouti.ChromeDriver( - agouti.Desired(caps), - agouti.ChromeOptions("args", []string{ - "--no-sandbox", - "--ignore-certificate-errors", - "--headless", // Comment out this line to see the tests happen in a visible browser window. - }), - // Uncomment this to see stdout/stderr from chromedriver. - // agouti.Debug, - ) - require.NoError(t, agoutiDriver.Start()) - t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) }) - page, err := agoutiDriver.NewPage(agouti.Browser("chrome")) - require.NoError(t, err) - require.NoError(t, page.Reset()) + page := browsertest.Open(t) // Build pinniped CLI. t.Logf("building CLI binary") @@ -262,28 +196,18 @@ func TestCLILoginOIDC(t *testing.T) { t.Logf("navigating to login page") require.NoError(t, page.Navigate(loginURL)) - // Expect to be redirected to the login page. - t.Logf("waiting for redirect to %s login page", loginProvider.Name) - waitForURL(t, page, loginProvider.LoginPagePattern) + // Expect to be redirected to the upstream provider and log in. + browsertest.LoginToUpstream(t, page, env.CLITestUpstream) - // Wait for the login page to be rendered. - waitForVisibleElements(t, page, loginProvider.UsernameSelector, loginProvider.PasswordSelector, loginProvider.LoginButtonSelector) - - // Fill in the username and password and click "submit". - t.Logf("logging into %s", loginProvider.Name) - require.NoError(t, page.First(loginProvider.UsernameSelector).Fill(env.CLITestUpstream.Username)) - require.NoError(t, page.First(loginProvider.PasswordSelector).Fill(env.CLITestUpstream.Password)) - require.NoError(t, page.First(loginProvider.LoginButtonSelector).Click()) - - // Wait for the login to happen and us be redirected back to a localhost callback. - t.Logf("waiting for redirect to localhost callback") + // Expect to be redirected to the localhost callback. + t.Logf("waiting for redirect to callback") callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(env.CLITestUpstream.CallbackURL) + `\?.+\z`) - waitForURL(t, page, callbackURLPattern) + browsertest.WaitForURL(t, page, callbackURLPattern) // Wait for the "pre" element that gets rendered for a `text/plain` page, and // assert that it contains the success message. t.Logf("verifying success page") - waitForVisibleElements(t, page, "pre") + browsertest.WaitForVisibleElements(t, page, "pre") msg, err := page.First("pre").Text() require.NoError(t, err) require.Equal(t, "you have been logged in and may now close this tab", msg) @@ -361,44 +285,6 @@ func TestCLILoginOIDC(t *testing.T) { require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token) } -func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { - t.Helper() - require.Eventually(t, - func() bool { - for _, sel := range selectors { - vis, err := page.First(sel).Visible() - if !(err == nil && vis) { - return false - } - } - return true - }, - 10*time.Second, - 100*time.Millisecond, - ) -} - -func waitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) { - var lastURL string - require.Eventuallyf(t, - func() bool { - url, err := page.URL() - if err == nil && pat.MatchString(url) { - return true - } - if url != lastURL { - t.Logf("saw URL %s", url) - lastURL = url - } - return false - }, - 10*time.Second, - 100*time.Millisecond, - "expected to browse to %s, but never got there", - pat, - ) -} - func readAndExpectEmpty(r io.Reader) (err error) { var remainder bytes.Buffer _, err = io.Copy(&remainder, r) diff --git a/test/integration/storage_test.go b/test/integration/storage_test.go index e2f3bdf2..501099fe 100644 --- a/test/integration/storage_test.go +++ b/test/integration/storage_test.go @@ -17,7 +17,7 @@ import ( "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "go.pinniped.dev/internal/fosite/authorizationcode" + "go.pinniped.dev/internal/fositestorage/authorizationcode" "go.pinniped.dev/test/library" ) @@ -29,7 +29,7 @@ func TestAuthorizeCodeStorage(t *testing.T) { // randomly generated HMAC authorization code (see below) code = "TQ72B8YjdEOZyxridYbTLE-pzoK4hpdkZxym5j4EmSc.TKRTgQG41IBQ16FDKTthRdhXfLlNaErcMd9Fy47uXAw" // name of the secret that will be created in Kube - name = "pinniped-storage-authorization-codes-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga" + name = "pinniped-storage-authcode-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga" ) hmac := compose.NewOAuth2HMACStrategy(&compose.Config{}, []byte("super-secret-32-byte-for-testing"), nil) diff --git a/test/integration/supervisor_discovery_test.go b/test/integration/supervisor_discovery_test.go index 396c0f48..32945490 100644 --- a/test/integration/supervisor_discovery_test.go +++ b/test/integration/supervisor_discovery_test.go @@ -111,7 +111,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) { // When the same issuer is added twice, both issuers are marked as duplicates, and neither provider is serving. config6Duplicate1, _ := requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(ctx, t, scheme, addr, caBundle, issuer6, client) - config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "") + config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "", "") requireStatus(t, client, ns, config6Duplicate1.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireStatus(t, client, ns, config6Duplicate2.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, issuer6) @@ -136,7 +136,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) { } // When we create a provider with an invalid issuer, the status is set to invalid. - badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "") + badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "", "") requireStatus(t, client, ns, badConfig.Name, v1alpha1.InvalidOIDCProviderStatusCondition) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, badIssuer) requireDeletingOIDCProviderCausesDiscoveryEndpointsToDisappear(t, badConfig, client, ns, scheme, addr, caBundle, badIssuer) @@ -162,7 +162,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) { certSecretName1 := "integration-test-cert-1" // Create an OIDCProvider with a spec.tls.secretName. - oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1) + oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1, "") requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // The spec.tls.secretName Secret does not exist, so the endpoints should fail with TLS errors. @@ -198,7 +198,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) { certSecretName2 := "integration-test-cert-2" // Create an OIDCProvider with a spec.tls.secretName. - oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2) + oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2, "") requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // Create the Secret. @@ -232,31 +232,30 @@ func TestSupervisorTLSTerminationWithDefaultCerts(t *testing.T) { port = hostAndPortSegments[1] } - ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname) + ips, err := library.LookupIP(ctx, hostname) require.NoError(t, err) - ip := ips[0] - ipAsString := ip.String() - ipWithPort := ipAsString + ":" + port + require.NotEmpty(t, ips) + ipWithPort := ips[0].String() + ":" + port issuerUsingIPAddress := fmt.Sprintf("%s://%s/issuer1", scheme, ipWithPort) issuerUsingHostname := fmt.Sprintf("%s://%s/issuer1", scheme, address) // Create an OIDCProvider without a spec.tls.secretName. - oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "") + oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "", "") requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // There is no default TLS cert and the spec.tls.secretName was not set, so the endpoints should fail with TLS errors. requireEndpointHasTLSErrorBecauseCertificatesAreNotReady(t, issuerUsingIPAddress) // Create a Secret at the special name which represents the default TLS cert. - defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ip.IP}, defaultTLSCertSecretName(env), kubeClient) + defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ips[0]}, defaultTLSCertSecretName(env), kubeClient) // Now that the Secret exists, we should be able to access the endpoints by IP address using the CA. _ = requireDiscoveryEndpointsAreWorking(t, scheme, ipWithPort, string(defaultCA.Bundle()), issuerUsingIPAddress, nil) // Create an OIDCProvider with a spec.tls.secretName. certSecretName := "integration-test-cert-1" - oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName) + oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName, "") requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) // Create the Secret. @@ -429,7 +428,7 @@ func requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear( client pinnipedclientset.Interface, ) (*v1alpha1.OIDCProvider, *ExpectedJWKSResponseFormat) { t.Helper() - newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "") + newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "", "") jwksResult := requireDiscoveryEndpointsAreWorking(t, supervisorScheme, supervisorAddress, supervisorCABundle, issuerName, nil) requireStatus(t, client, newOIDCProvider.Namespace, newOIDCProvider.Name, v1alpha1.SuccessOIDCProviderStatusCondition) return newOIDCProvider, jwksResult diff --git a/test/integration/supervisor_keys_test.go b/test/integration/supervisor_keys_test.go index 17e6a580..d59c713e 100644 --- a/test/integration/supervisor_keys_test.go +++ b/test/integration/supervisor_keys_test.go @@ -27,7 +27,7 @@ func TestSupervisorOIDCKeys(t *testing.T) { defer cancel() // Create our OPC under test. - opc := library.CreateTestOIDCProvider(ctx, t, "", "") + opc := library.CreateTestOIDCProvider(ctx, t, "", "", "") // Ensure a secret is created with the OPC's JWKS. var updatedOPC *configv1alpha1.OIDCProvider diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index ca5c2787..5df74651 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -6,236 +6,180 @@ package integration import ( "context" "crypto/tls" - "crypto/x509" + "crypto/x509/pkix" "encoding/base64" - "fmt" "net/http" + "net/http/httptest" "net/url" - "path" + "regexp" "strings" "testing" "time" "github.com/coreos/go-oidc" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" + configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1" idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" + "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/state" "go.pinniped.dev/test/library" + "go.pinniped.dev/test/library/browsertest" ) func TestSupervisorLogin(t *testing.T) { - t.Skip("waiting on new callback path logic to get merged in from the callback endpoint work") - env := library.IntegrationEnv(t) - client := library.NewSupervisorClientset(t) + + // If anything in this test crashes, dump out the supervisor pod logs. + defer library.DumpLogs(t, env.SupervisorNamespace) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() - tests := []struct { - Scheme string - Address string - CABundle string - }{ - {Scheme: "http", Address: env.SupervisorHTTPAddress}, - {Scheme: "https", Address: env.SupervisorHTTPSIngressAddress, CABundle: env.SupervisorHTTPSIngressCABundle}, - } + // Infer the downstream issuer URL from the callback associated with the upstream test client registration. + issuerURL, err := url.Parse(env.SupervisorTestUpstream.CallbackURL) + require.NoError(t, err) + require.True(t, strings.HasSuffix(issuerURL.Path, "/callback")) + issuerURL.Path = strings.TrimSuffix(issuerURL.Path, "/callback") + t.Logf("testing with downstream issuer URL %s", issuerURL.String()) - for _, test := range tests { - scheme := test.Scheme - addr := test.Address - caBundle := test.CABundle - - if addr == "" { - // Both cases are not required, so when one is empty skip it. - continue - } - - // Create downstream OIDC provider (i.e., update supervisor with OIDC provider). - path := getDownstreamIssuerPathFromUpstreamRedirectURI(t, env.SupervisorTestUpstream.CallbackURL) - issuer := fmt.Sprintf("https://%s%s", addr, path) - _, _ = requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear( - ctx, - t, - scheme, - addr, - caBundle, - issuer, - client, - ) - - // Create HTTP client. - httpClient := newHTTPClient(t, caBundle, nil) - httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { - // Don't follow any redirects right now, since we simply want to validate that our auth endpoint - // redirects us. - return http.ErrUseLastResponse - } - - // Declare the downstream auth endpoint url we will use. - downstreamAuthURL := makeDownstreamAuthURL(t, scheme, addr, path) - - // Make request to auth endpoint - should fail, since we have no upstreams. - req, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) - require.NoError(t, err) - rsp, err := httpClient.Do(req) - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusUnprocessableEntity, rsp.StatusCode) - - // Create upstream OIDC provider. - spec := idpv1alpha1.UpstreamOIDCProviderSpec{ - Issuer: env.SupervisorTestUpstream.Issuer, - TLS: &idpv1alpha1.TLSSpec{ - CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), - }, - Client: idpv1alpha1.OIDCClient{ - SecretName: makeTestClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, - }, - } - upstream := makeTestUpstream(t, spec, idpv1alpha1.PhaseReady) - - // Make request to authorize endpoint - should pass, since we now have an upstream. - req, err = http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) - require.NoError(t, err) - rsp, err = httpClient.Do(req) - require.NoError(t, err) - defer rsp.Body.Close() - require.Equal(t, http.StatusFound, rsp.StatusCode) - requireValidRedirectLocation( - ctx, - t, - upstream.Spec.Issuer, - env.SupervisorTestUpstream.ClientID, - env.SupervisorTestUpstream.CallbackURL, - rsp.Header.Get("Location"), - ) - } -} - -//nolint:unused -func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedirectURI string) string { - // We need to construct the downstream issuer path from the upstream redirect URI since the two - // are related, and the upstream redirect URI is supplied via a static test environment - // variable. The upstream redirect URI should be something like - // https://supervisor.com/some/supervisor/path/callback - // and therefore the downstream issuer should be something like - // https://supervisor.com/some/supervisor/path - // since the /callback endpoint is placed at the root of the downstream issuer path. - upstreamRedirectURL, err := url.Parse(upstreamRedirectURI) + // Generate a CA bundle with which to serve this provider. + t.Logf("generating test CA") + ca, err := certauthority.New(pkix.Name{CommonName: "Downstream Test CA"}, 1*time.Hour) require.NoError(t, err) - redirectURIPathWithoutLastSegment, lastUpstreamRedirectURIPathSegment := path.Split(upstreamRedirectURL.Path) - require.Equalf( - t, - "callback", - lastUpstreamRedirectURIPathSegment, - "expected upstream redirect URI (%q) to follow supervisor callback path conventions (i.e., end in /callback)", - upstreamRedirectURI, + // Create an HTTP client that can reach the downstream discovery endpoint using the CA certs. + httpClient := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, + Proxy: func(req *http.Request) (*url.URL, error) { + if env.Proxy == "" { + return nil, nil + } + return url.Parse(env.Proxy) + }, + }} + + // Use the CA to issue a TLS server cert. + t.Logf("issuing test certificate") + tlsCert, err := ca.Issue( + pkix.Name{CommonName: issuerURL.Hostname()}, + []string{issuerURL.Hostname()}, + nil, + 1*time.Hour, + ) + require.NoError(t, err) + certPEM, keyPEM, err := certauthority.ToPEM(tlsCert) + require.NoError(t, err) + + // Write the serving cert to a secret. + certSecret := library.CreateTestSecret(t, + env.SupervisorNamespace, + "oidc-provider-tls", + "kubernetes.io/tls", + map[string]string{"tls.crt": string(certPEM), "tls.key": string(keyPEM)}, ) - if strings.HasSuffix(redirectURIPathWithoutLastSegment, "/") { - redirectURIPathWithoutLastSegment = redirectURIPathWithoutLastSegment[:len(redirectURIPathWithoutLastSegment)-1] - } + // Create the downstream OIDCProvider and expect it to go into the success status condition. + downstream := library.CreateTestOIDCProvider(ctx, t, + issuerURL.String(), + certSecret.Name, + configv1alpha1.SuccessOIDCProviderStatusCondition, + ) - return redirectURIPathWithoutLastSegment -} + // Create upstream OIDC provider and wait for it to become ready. + library.CreateTestUpstreamOIDCProvider(t, idpv1alpha1.UpstreamOIDCProviderSpec{ + Issuer: env.SupervisorTestUpstream.Issuer, + TLS: &idpv1alpha1.TLSSpec{ + CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), + }, + Client: idpv1alpha1.OIDCClient{ + SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, + }, + }, idpv1alpha1.PhaseReady) -//nolint:unused -func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string { - t.Helper() + // Perform OIDC discovery for our downstream. + var discovery *oidc.Provider + assert.Eventually(t, func() bool { + discovery, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer) + return err == nil + }, 60*time.Second, 1*time.Second) + require.NoError(t, err) + + // Start a callback server on localhost. + localCallbackServer := startLocalCallbackServer(t) + + // Form the OAuth2 configuration corresponding to our CLI client. downstreamOAuth2Config := oauth2.Config{ // This is the hardcoded public client that the supervisor supports. - ClientID: "pinniped-cli", - Endpoint: oauth2.Endpoint{ - AuthURL: fmt.Sprintf("%s://%s%s/oauth2/authorize", scheme, addr, path), - }, - // This is the hardcoded downstream redirect URI that the supervisor supports. - RedirectURL: "http://127.0.0.1/callback", + ClientID: "pinniped-cli", + Endpoint: discovery.Endpoint(), + RedirectURL: localCallbackServer.URL, Scopes: []string{"openid"}, } - state, nonce, pkce := generateAuthRequestParams(t) - return downstreamOAuth2Config.AuthCodeURL( - state.String(), - nonce.Param(), - pkce.Challenge(), - pkce.Method(), + + // Build a valid downstream authorize URL for the supervisor. + stateParam, err := state.Generate() + require.NoError(t, err) + nonceParam, err := nonce.Generate() + require.NoError(t, err) + pkceParam, err := pkce.Generate() + require.NoError(t, err) + downstreamAuthorizeURL := downstreamOAuth2Config.AuthCodeURL( + stateParam.String(), + nonceParam.Param(), + pkceParam.Challenge(), + pkceParam.Method(), ) + + // Open the web browser and navigate to the downstream authorize URL. + page := browsertest.Open(t) + t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL)) + require.NoError(t, page.Navigate(downstreamAuthorizeURL)) + + // Expect to be redirected to the upstream provider and log in. + browsertest.LoginToUpstream(t, page, env.SupervisorTestUpstream) + + // Wait for the login to happen and us be redirected back to a localhost callback. + t.Logf("waiting for redirect to callback") + callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(localCallbackServer.URL) + `\?.+\z`) + browsertest.WaitForURL(t, page, callbackURLPattern) + + // Expect that our callback handler was invoked. + callback := localCallbackServer.waitForCallback(10 * time.Second) + t.Logf("got callback request: %s", library.MaskTokens(callback.URL.String())) + require.Equal(t, stateParam.String(), callback.URL.Query().Get("state")) + require.Equal(t, "openid", callback.URL.Query().Get("scope")) + require.NotEmpty(t, callback.URL.Query().Get("code")) } -//nolint:unused -func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Code) { - t.Helper() - state, err := state.Generate() - require.NoError(t, err) - nonce, err := nonce.Generate() - require.NoError(t, err) - pkce, err := pkce.Generate() - require.NoError(t, err) - return state, nonce, pkce +func startLocalCallbackServer(t *testing.T) *localCallbackServer { + // Handle the callback by sending the *http.Request object back through a channel. + callbacks := make(chan *http.Request, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callbacks <- r + })) + server.URL += "/callback" + t.Cleanup(server.Close) + t.Cleanup(func() { close(callbacks) }) + return &localCallbackServer{Server: server, t: t, callbacks: callbacks} } -//nolint:unused -func requireValidRedirectLocation( - ctx context.Context, - t *testing.T, - issuer, clientID, redirectURI, actualLocation string, -) { - t.Helper() - env := library.IntegrationEnv(t) +type localCallbackServer struct { + *httptest.Server + t *testing.T + callbacks <-chan *http.Request +} - // Do OIDC discovery on our test issuer to get auth endpoint. - transport := http.Transport{} - if env.Proxy != "" { - transport.Proxy = func(_ *http.Request) (*url.URL, error) { - return url.Parse(env.Proxy) - } +func (s *localCallbackServer) waitForCallback(timeout time.Duration) *http.Request { + select { + case callback := <-s.callbacks: + return callback + case <-time.After(timeout): + require.Fail(s.t, "timed out waiting for callback request") + return nil } - if env.SupervisorTestUpstream.CABundle != "" { - transport.TLSClientConfig = &tls.Config{RootCAs: x509.NewCertPool()} - transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(env.SupervisorTestUpstream.CABundle)) - } - - ctx = oidc.ClientContext(ctx, &http.Client{Transport: &transport}) - upstreamProvider, err := oidc.NewProvider(ctx, issuer) - require.NoError(t, err) - - // Parse expected upstream auth URL. - expectedLocationURL, err := url.Parse( - (&oauth2.Config{ - ClientID: clientID, - Endpoint: upstreamProvider.Endpoint(), - RedirectURL: redirectURI, - Scopes: []string{"openid"}, - }).AuthCodeURL("", oauth2.AccessTypeOffline), - ) - require.NoError(t, err) - - // Parse actual upstream auth URL. - actualLocationURL, err := url.Parse(actualLocation) - require.NoError(t, err) - - // First make some assertions on the query values. Note that we will not be able to know what - // certain query values are since they may be random (e.g., state, pkce, nonce). - expectedLocationQuery := expectedLocationURL.Query() - actualLocationQuery := actualLocationURL.Query() - require.NotEmpty(t, actualLocationQuery.Get("state")) - actualLocationQuery.Del("state") - require.NotEmpty(t, actualLocationQuery.Get("code_challenge")) - actualLocationQuery.Del("code_challenge") - require.NotEmpty(t, actualLocationQuery.Get("code_challenge_method")) - actualLocationQuery.Del("code_challenge_method") - require.NotEmpty(t, actualLocationQuery.Get("nonce")) - actualLocationQuery.Del("nonce") - require.Equal(t, expectedLocationQuery, actualLocationQuery) - - // Zero-out query values, since we made specific assertions about those above, and assert that the - // URL's are equal otherwise. - expectedLocationURL.RawQuery = "" - actualLocationURL.RawQuery = "" - require.Equal(t, expectedLocationURL, actualLocationURL) } diff --git a/test/integration/supervisor_upstream_test.go b/test/integration/supervisor_upstream_test.go index e38f2b17..dd3fa528 100644 --- a/test/integration/supervisor_upstream_test.go +++ b/test/integration/supervisor_upstream_test.go @@ -4,13 +4,10 @@ package integration import ( - "context" "encoding/base64" "testing" - "time" "github.com/stretchr/testify/require" - corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" @@ -28,7 +25,7 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) { SecretName: "does-not-exist", }, } - upstream := makeTestUpstream(t, spec, v1alpha1.PhaseError) + upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseError) expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ { Type: "ClientCredentialsValid", @@ -56,10 +53,10 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) { AdditionalScopes: []string{"email", "profile"}, }, Client: v1alpha1.OIDCClient{ - SecretName: makeTestClientCredsSecret(t, "test-client-id", "test-client-secret").Name, + SecretName: library.CreateClientCredsSecret(t, "test-client-id", "test-client-secret").Name, }, } - upstream := makeTestUpstream(t, spec, v1alpha1.PhaseReady) + upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseReady) expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ { Type: "ClientCredentialsValid", @@ -87,74 +84,3 @@ func expectUpstreamConditions(t *testing.T, upstream *v1alpha1.UpstreamOIDCProvi } require.ElementsMatch(t, expected, normalized) } - -func makeTestClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret { - t.Helper() - env := library.IntegrationEnv(t) - client := library.NewClientset(t) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - created, err := client.CoreV1().Secrets(env.SupervisorNamespace).Create(ctx, &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: env.SupervisorNamespace, - GenerateName: "test-client-creds-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Type: "secrets.pinniped.dev/oidc-client", - StringData: map[string]string{ - "clientID": clientID, - "clientSecret": clientSecret, - }, - }, metav1.CreateOptions{}) - require.NoError(t, err) - t.Cleanup(func() { - err := client.CoreV1().Secrets(env.SupervisorNamespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{}) - require.NoError(t, err) - }) - t.Logf("created test client credentials Secret %s", created.Name) - return created -} - -func makeTestUpstream(t *testing.T, spec v1alpha1.UpstreamOIDCProviderSpec, expectedPhase v1alpha1.UpstreamOIDCProviderPhase) *v1alpha1.UpstreamOIDCProvider { - t.Helper() - env := library.IntegrationEnv(t) - client := library.NewSupervisorClientset(t) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - // Create the UpstreamOIDCProvider using GenerateName to get a random name. - created, err := client.IDPV1alpha1(). - UpstreamOIDCProviders(env.SupervisorNamespace). - Create(ctx, &v1alpha1.UpstreamOIDCProvider{ - ObjectMeta: metav1.ObjectMeta{ - Namespace: env.SupervisorNamespace, - GenerateName: "test-upstream-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Spec: spec, - }, metav1.CreateOptions{}) - require.NoError(t, err) - - // Always clean this up after this point. - t.Cleanup(func() { - err := client.IDPV1alpha1(). - UpstreamOIDCProviders(env.SupervisorNamespace). - Delete(context.Background(), created.Name, metav1.DeleteOptions{}) - require.NoError(t, err) - }) - t.Logf("created test UpstreamOIDCProvider %s", created.Name) - - // Wait for the UpstreamOIDCProvider to enter the expected phase (or time out). - var result *v1alpha1.UpstreamOIDCProvider - require.Eventuallyf(t, func() bool { - var err error - result, err = client.IDPV1alpha1(). - UpstreamOIDCProviders(created.Namespace).Get(ctx, created.Name, metav1.GetOptions{}) - require.NoError(t, err) - return result.Status.Phase == expectedPhase - }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase) - return result -} diff --git a/test/library/browsertest/browsertest.go b/test/library/browsertest/browsertest.go new file mode 100644 index 00000000..d7da8142 --- /dev/null +++ b/test/library/browsertest/browsertest.go @@ -0,0 +1,158 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package browsertest provides integration test helpers for our browser-based tests. +package browsertest + +import ( + "regexp" + "testing" + "time" + + "github.com/sclevine/agouti" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/test/library" +) + +const ( + operationTimeout = 10 * time.Second + operationPollingInterval = 100 * time.Millisecond +) + +// Open a webdriver-driven browser and returns an *agouti.Page to control it. The browser will be automatically +// closed at the end of the current test. It is configured for test purposes with the correct HTTP proxy and +// in a mode that ignore certificate errors. +func Open(t *testing.T) *agouti.Page { + t.Logf("opening browser driver") + env := library.IntegrationEnv(t) + caps := agouti.NewCapabilities() + if env.Proxy != "" { + t.Logf("configuring Chrome to use proxy %q", env.Proxy) + caps = caps.Proxy(agouti.ProxyConfig{ + ProxyType: "manual", + HTTPProxy: env.Proxy, + SSLProxy: env.Proxy, + NoProxy: "127.0.0.1", + }) + } + agoutiDriver := agouti.ChromeDriver( + agouti.Desired(caps), + agouti.ChromeOptions("args", []string{ + "--no-sandbox", + "--ignore-certificate-errors", + "--headless", // Comment out this line to see the tests happen in a visible browser window. + }), + // Uncomment this to see stdout/stderr from chromedriver. + // agouti.Debug, + ) + require.NoError(t, agoutiDriver.Start()) + t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) }) + page, err := agoutiDriver.NewPage(agouti.Browser("chrome")) + require.NoError(t, err) + require.NoError(t, page.Reset()) + return page +} + +// WaitForVisibleElements expects the page to contain all the the elements specified by the selectors. It waits for this +// to occur and times out, failing the test, if they never appear. +func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { + t.Helper() + + require.Eventuallyf(t, + func() bool { + for _, sel := range selectors { + vis, err := page.First(sel).Visible() + if !(err == nil && vis) { + return false + } + } + return true + }, + operationTimeout, + operationPollingInterval, + "expected to have a page with selectors %v, but it never loaded", + selectors, + ) +} + +// WaitForURL expects the page to eventually navigate to a URL matching the specified pattern. It waits for this +// to occur and times out, failing the test, if it never does. +func WaitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) { + var lastURL string + require.Eventuallyf(t, + func() bool { + url, err := page.URL() + if err == nil && pat.MatchString(url) { + return true + } + if url != lastURL { + t.Logf("saw URL %s", url) + lastURL = url + } + return false + }, + operationTimeout, + operationPollingInterval, + "expected to browse to %s, but never got there", + pat, + ) +} + +// LoginToUpstream expects the page to be redirected to one of several known upstream IDPs. +// It knows how to enter the test username/password and submit the upstream login form. +func LoginToUpstream(t *testing.T, page *agouti.Page, upstream library.TestOIDCUpstream) { + t.Helper() + + type config struct { + Name string + IssuerPattern *regexp.Regexp + LoginPagePattern *regexp.Regexp + UsernameSelector string + PasswordSelector string + LoginButtonSelector string + } + + // Lookup the provider by matching on the issuer URL. + var cfg *config + for _, p := range []*config{ + { + Name: "Okta", + IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), + LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`), + UsernameSelector: "input#okta-signin-username", + PasswordSelector: "input#okta-signin-password", + LoginButtonSelector: "input#okta-signin-submit", + }, + { + Name: "Dex", + IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`), + LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`), + UsernameSelector: "input#login", + PasswordSelector: "input#password", + LoginButtonSelector: "button#submit-login", + }, + } { + if p.IssuerPattern.MatchString(upstream.Issuer) { + cfg = p + break + } + } + if cfg == nil { + require.Failf(t, "could not find login provider for issuer %q", upstream.Issuer) + return + } + + // Expect to be redirected to the login page. + t.Logf("waiting for redirect to %s login page", cfg.Name) + WaitForURL(t, page, cfg.LoginPagePattern) + + // Wait for the login page to be rendered. + WaitForVisibleElements(t, page, cfg.UsernameSelector, cfg.PasswordSelector, cfg.LoginButtonSelector) + + // Fill in the username and password and click "submit". + t.Logf("logging into %s", cfg.Name) + require.NoError(t, page.First(cfg.UsernameSelector).Fill(upstream.Username)) + require.NoError(t, page.First(cfg.PasswordSelector).Fill(upstream.Password)) + require.NoError(t, page.First(cfg.LoginButtonSelector).Click()) +} diff --git a/test/library/client.go b/test/library/client.go index f5aba3a6..d95f0426 100644 --- a/test/library/client.go +++ b/test/library/client.go @@ -25,6 +25,7 @@ import ( auth1alpha1 "go.pinniped.dev/generated/1.19/apis/concierge/authentication/v1alpha1" configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1" + idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" conciergeclientset "go.pinniped.dev/generated/1.19/client/concierge/clientset/versioned" supervisorclientset "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned" @@ -140,12 +141,8 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty defer cancel() webhook, err := webhooks.Create(createContext, &auth1alpha1.WebhookAuthenticator{ - ObjectMeta: metav1.ObjectMeta{ - GenerateName: "test-webhook-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, - Spec: testEnv.TestWebhook, + ObjectMeta: testObjectMeta(t, "webhook"), + Spec: testEnv.TestWebhook, }, metav1.CreateOptions{}) require.NoError(t, err, "could not create test WebhookAuthenticator") t.Logf("created test WebhookAuthenticator %s/%s", webhook.Namespace, webhook.Name) @@ -172,7 +169,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty // // If the provided issuer is not the empty string, then it will be used for the // OIDCProvider.Spec.Issuer field. Else, a random issuer will be generated. -func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecretName string) *configv1alpha1.OIDCProvider { +func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer string, certSecretName string, expectStatus configv1alpha1.OIDCProviderStatusCondition) *configv1alpha1.OIDCProvider { t.Helper() testEnv := IntegrationEnv(t) @@ -180,18 +177,12 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre defer cancel() if issuer == "" { - var err error - issuer, err = randomIssuer() - require.NoError(t, err) + issuer = randomIssuer(t) } opcs := NewSupervisorClientset(t).ConfigV1alpha1().OIDCProviders(testEnv.SupervisorNamespace) opc, err := opcs.Create(createContext, &configv1alpha1.OIDCProvider{ - ObjectMeta: metav1.ObjectMeta{ - GenerateName: "test-oidc-provider-", - Labels: map[string]string{"pinniped.dev/test": ""}, - Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, - }, + ObjectMeta: testObjectMeta(t, "oidc-provider"), Spec: configv1alpha1.OIDCProviderSpec{ Issuer: issuer, TLS: &configv1alpha1.OIDCProviderTLSSpec{SecretName: certSecretName}, @@ -213,13 +204,103 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre } }) + // If we're not expecting any particular status, just return the new OIDCProvider immediately. + if expectStatus == "" { + return opc + } + + // Wait for the OIDCProvider to enter the expected phase (or time out). + var result *configv1alpha1.OIDCProvider + require.Eventuallyf(t, func() bool { + var err error + result, err = opcs.Get(ctx, opc.Name, metav1.GetOptions{}) + require.NoError(t, err) + return result.Status.Status == expectStatus + }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectStatus) + return opc } -func randomIssuer() (string, error) { +func randomIssuer(t *testing.T) string { var buf [8]byte - if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil { - return "", fmt.Errorf("could not generate random state: %w", err) - } - return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])), nil + _, err := io.ReadFull(rand.Reader, buf[:]) + require.NoError(t, err) + return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])) +} + +func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType string, stringData map[string]string) *corev1.Secret { + t.Helper() + client := NewClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + created, err := client.CoreV1().Secrets(namespace).Create(ctx, &corev1.Secret{ + ObjectMeta: testObjectMeta(t, baseName), + Type: corev1.SecretType(secretType), + StringData: stringData, + }, metav1.CreateOptions{}) + require.NoError(t, err) + + t.Cleanup(func() { + err := client.CoreV1().Secrets(namespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{}) + require.NoError(t, err) + }) + t.Logf("created test Secret %s", created.Name) + return created +} + +func CreateClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret { + t.Helper() + env := IntegrationEnv(t) + return CreateTestSecret(t, + env.SupervisorNamespace, + "test-client-creds-", + "secrets.pinniped.dev/oidc-client", + map[string]string{ + "clientID": clientID, + "clientSecret": clientSecret, + }, + ) +} + +func CreateTestUpstreamOIDCProvider(t *testing.T, spec idpv1alpha1.UpstreamOIDCProviderSpec, expectedPhase idpv1alpha1.UpstreamOIDCProviderPhase) *idpv1alpha1.UpstreamOIDCProvider { + t.Helper() + env := IntegrationEnv(t) + client := NewSupervisorClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + // Create the UpstreamOIDCProvider using GenerateName to get a random name. + upstreams := client.IDPV1alpha1().UpstreamOIDCProviders(env.SupervisorNamespace) + + created, err := upstreams.Create(ctx, &idpv1alpha1.UpstreamOIDCProvider{ + ObjectMeta: testObjectMeta(t, "upstream"), + Spec: spec, + }, metav1.CreateOptions{}) + require.NoError(t, err) + + // Always clean this up after this point. + t.Cleanup(func() { + err := upstreams.Delete(context.Background(), created.Name, metav1.DeleteOptions{}) + require.NoError(t, err) + }) + t.Logf("created test UpstreamOIDCProvider %s", created.Name) + + // Wait for the UpstreamOIDCProvider to enter the expected phase (or time out). + var result *idpv1alpha1.UpstreamOIDCProvider + require.Eventuallyf(t, func() bool { + var err error + result, err = upstreams.Get(ctx, created.Name, metav1.GetOptions{}) + require.NoError(t, err) + return result.Status.Phase == expectedPhase + }, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase) + return result +} + +func testObjectMeta(t *testing.T, baseName string) metav1.ObjectMeta { + return metav1.ObjectMeta{ + GenerateName: fmt.Sprintf("test-%s-", baseName), + Labels: map[string]string{"pinniped.dev/test": ""}, + Annotations: map[string]string{"pinniped.dev/testName": t.Name()}, + } } diff --git a/test/library/dumplogs.go b/test/library/dumplogs.go new file mode 100644 index 00000000..33f33694 --- /dev/null +++ b/test/library/dumplogs.go @@ -0,0 +1,49 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package library + +import ( + "bufio" + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// DumpLogs is meant to be called in a `defer` to dump the logs of components in the cluster on a test failure. +func DumpLogs(t *testing.T, namespace string) { + // Only trigger on failed tests. + if !t.Failed() { + return + } + + kubeClient := NewClientset(t) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + logTailLines := int64(40) + pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{}) + require.NoError(t, err) + + for _, pod := range pods.Items { + for _, container := range pod.Status.ContainerStatuses { + t.Logf("pod %s/%s container %s restarted %d times:", pod.Namespace, pod.Name, container.Name, container.RestartCount) + req := kubeClient.CoreV1().Pods(namespace).GetLogs(pod.Name, &corev1.PodLogOptions{ + Container: container.Name, + TailLines: &logTailLines, + }) + logReader, err := req.Stream(ctx) + require.NoError(t, err) + + scanner := bufio.NewScanner(logReader) + for scanner.Scan() { + t.Logf("%s/%s/%s > %s", pod.Namespace, pod.Name, container.Name, scanner.Text()) + } + require.NoError(t, scanner.Err()) + } + } +} diff --git a/test/library/iotest.go b/test/library/iotest.go index dcb0e695..6e2f1e58 100644 --- a/test/library/iotest.go +++ b/test/library/iotest.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "regexp" + "strings" "testing" ) @@ -26,18 +27,22 @@ func (l *testlogReader) Read(p []byte) (n int, err error) { l.t.Helper() n, err = l.r.Read(p) if err != nil { - l.t.Logf("%s > %q: %v", l.name, maskTokens(p[0:n]), err) + l.t.Logf("%s > %q: %v", l.name, MaskTokens(string(p[0:n])), err) } else { - l.t.Logf("%s > %q", l.name, maskTokens(p[0:n])) + l.t.Logf("%s > %q", l.name, MaskTokens(string(p[0:n]))) } return } -//nolint: gochecknoglobals -var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) - -func maskTokens(in []byte) string { - return tokenLike.ReplaceAllStringFunc(string(in), func(t string) string { +// MaskTokens makes a best-effort attempt to mask out things that look like secret tokens in test output. +// The goal is more to have readable test output than for any security reason. +func MaskTokens(in string) string { + var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) + return tokenLike.ReplaceAllStringFunc(in, func(t string) string { + // This is a silly heuristic, but things with multiple dots are more likely hostnames that we don't want masked. + if strings.Count(t, ".") >= 4 { + return t + } return fmt.Sprintf("[...%d bytes...]", len(t)) }) } diff --git a/test/library/iplookup.go b/test/library/iplookup.go new file mode 100644 index 00000000..c3ce6cc3 --- /dev/null +++ b/test/library/iplookup.go @@ -0,0 +1,16 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// +build !go1.14 + +package library + +import ( + "context" + "net" +) + +// LookupIP looks up the IP address of the provided hostname, preferring IPv4. +func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) { + return net.DefaultResolver.LookupIP(ctx, "ip4", hostname) +} diff --git a/test/library/iplookup_go1.14.go b/test/library/iplookup_go1.14.go new file mode 100644 index 00000000..c4a4a7e6 --- /dev/null +++ b/test/library/iplookup_go1.14.go @@ -0,0 +1,28 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// +build go1.14 + +package library + +import ( + "context" + "net" +) + +// LookupIP looks up the IP address of the provided hostname, preferring IPv4. +func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) { + ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname) + if err != nil { + return nil, err + } + + // Filter out to only IPv4 addresses + var results []net.IP + for _, ip := range ips { + if ip.IP.To4() != nil { + results = append(results, ip.IP) + } + } + return results, nil +}