Merge remote-tracking branch 'upstream/callback-endpoint' into token-endpoint

This commit is contained in:
Andrew Keesler 2020-12-03 11:14:37 -05:00
commit 2f1a67ef0d
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
56 changed files with 4396 additions and 1545 deletions

View File

@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error {
dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider() dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider()
// OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux. // OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux.
oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider) oidProvidersManager := manager.NewManager(
healthMux,
dynamicJWKSProvider,
dynamicUpstreamIDPProvider,
kubeClient.CoreV1().Secrets(serverInstallationNamespace),
)
startControllers( startControllers(
ctx, ctx,

View File

@ -20,6 +20,7 @@ import (
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/filesession" "go.pinniped.dev/pkg/oidcclient/filesession"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
//nolint: gochecknoinits //nolint: gochecknoinits
@ -27,7 +28,7 @@ func init() {
loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login)) loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login))
} }
func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error)) *cobra.Command { func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error)) *cobra.Command {
var ( var (
cmd = cobra.Command{ cmd = cobra.Command{
Args: cobra.NoArgs, Args: cobra.NoArgs,

View File

@ -13,6 +13,7 @@ import (
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
func TestLoginOIDCCommand(t *testing.T) { func TestLoginOIDCCommand(t *testing.T) {
@ -92,12 +93,12 @@ func TestLoginOIDCCommand(t *testing.T) {
gotClientID string gotClientID string
gotOptions []oidcclient.Option gotOptions []oidcclient.Option
) )
cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error) { cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error) {
gotIssuer = issuer gotIssuer = issuer
gotClientID = clientID gotClientID = clientID
gotOptions = opts gotOptions = opts
return &oidcclient.Token{ return &oidctypes.Token{
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(time1), Expiry: metav1.NewTime(time1),
}, },

1
go.mod
View File

@ -18,6 +18,7 @@ require (
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1
github.com/ory/fosite v0.35.1 github.com/ory/fosite v0.35.1
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/pkg/errors v0.9.1
github.com/sclevine/agouti v3.0.0+incompatible github.com/sclevine/agouti v3.0.0+incompatible
github.com/sclevine/spec v1.4.0 github.com/sclevine/spec v1.4.0
github.com/spf13/cobra v1.0.0 github.com/spf13/cobra v1.0.0

View File

@ -103,6 +103,7 @@ k8s_yaml(local([
'--data-value-yaml', 'service_http_nodeport_nodeport=31234', '--data-value-yaml', 'service_http_nodeport_nodeport=31234',
'--data-value-yaml', 'service_https_nodeport_port=443', '--data-value-yaml', 'service_https_nodeport_port=443',
'--data-value-yaml', 'service_https_nodeport_nodeport=31243', '--data-value-yaml', 'service_https_nodeport_nodeport=31243',
'--data-value-yaml', 'service_https_clusterip_port=443',
'--data-value-yaml', 'custom_labels={mySupervisorCustomLabelName: mySupervisorCustomLabelValue}', '--data-value-yaml', 'custom_labels={mySupervisorCustomLabelName: mySupervisorCustomLabelValue}',
])) ]))
# Tell tilt to watch all of those files for changes. # Tell tilt to watch all of those files for changes.

View File

@ -230,6 +230,7 @@ if ! tilt_mode; then
--data-value-yaml 'service_http_nodeport_nodeport=31234' \ --data-value-yaml 'service_http_nodeport_nodeport=31234' \
--data-value-yaml 'service_https_nodeport_port=443' \ --data-value-yaml 'service_https_nodeport_port=443' \
--data-value-yaml 'service_https_nodeport_nodeport=31243' \ --data-value-yaml 'service_https_nodeport_nodeport=31243' \
--data-value-yaml 'service_https_clusterip_port=443' \
>"$manifest" >"$manifest"
kapp deploy --yes --app "$supervisor_app_name" --diff-changes --file "$manifest" kapp deploy --yes --app "$supervisor_app_name" --diff-changes --file "$manifest"
@ -302,7 +303,7 @@ export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER=https://dex.dex.svc.cluster
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER_CA_BUNDLE="${test_ca_bundle_pem}" export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER_CA_BUNDLE="${test_ca_bundle_pem}"
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_ID=pinniped-supervisor export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_ID=pinniped-supervisor
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_SECRET=pinniped-supervisor-secret export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_SECRET=pinniped-supervisor-secret
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://127.0.0.1:12345/some/path/callback export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME=pinny@example.com export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME=pinny@example.com
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD=password export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD=password

View File

@ -136,6 +136,13 @@ func (c *CA) Bundle() []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes}) return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes})
} }
// Pool returns the current CA signing bundle as a *x509.CertPool.
func (c *CA) Pool() *x509.CertPool {
pool := x509.NewCertPool()
pool.AppendCertsFromPEM(c.Bundle())
return pool
}
// Issue a new server certificate for the given identity and duration. // Issue a new server certificate for the given identity and duration.
func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time.Duration) (*tls.Certificate, error) { func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time.Duration) (*tls.Certificate, error) {
// Choose a random 128 bit serial number. // Choose a random 128 bit serial number.

View File

@ -182,6 +182,16 @@ func TestBundle(t *testing.T) {
}) })
} }
func TestPool(t *testing.T) {
t.Run("success", func(t *testing.T) {
ca, err := New(pkix.Name{CommonName: "test"}, 1*time.Hour)
require.NoError(t, err)
got := ca.Pool()
require.Len(t, got.Subjects(), 1)
})
}
type errSigner struct { type errSigner struct {
pubkey crypto.PublicKey pubkey crypto.PublicKey
err error err error

View File

@ -17,6 +17,7 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/go-logr/logr" "github.com/go-logr/logr"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/api/equality"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
@ -30,6 +31,7 @@ import (
pinnipedcontroller "go.pinniped.dev/internal/controller" pinnipedcontroller "go.pinniped.dev/internal/controller"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/upstreamoidc"
) )
const ( const (
@ -62,21 +64,27 @@ const (
// IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. // IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
type IDPCache interface { type IDPCache interface {
SetIDPList([]provider.UpstreamOIDCIdentityProvider) SetIDPList([]provider.UpstreamOIDCIdentityProviderI)
} }
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
type lruValidatorCache struct{ cache *cache.Expiring } type lruValidatorCache struct{ cache *cache.Expiring }
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider { type lruValidatorCacheEntry struct {
if result, ok := c.cache.Get(c.cacheKey(spec)); ok { provider *oidc.Provider
return result.(*oidc.Provider) client *http.Client
}
return nil
} }
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) { func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) {
c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL) if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
entry := result.(*lruValidatorCacheEntry)
return entry.provider, entry.client
}
return nil, nil
}
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) {
c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL)
} }
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} { func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} {
@ -95,8 +103,8 @@ type controller struct {
providers idpinformers.UpstreamOIDCProviderInformer providers idpinformers.UpstreamOIDCProviderInformer
secrets corev1informers.SecretInformer secrets corev1informers.SecretInformer
validatorCache interface { validatorCache interface {
getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client)
putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client)
} }
} }
@ -132,13 +140,13 @@ func (c *controller) Sync(ctx controllerlib.Context) error {
} }
requeue := false requeue := false
validatedUpstreams := make([]provider.UpstreamOIDCIdentityProvider, 0, len(actualUpstreams)) validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams { for _, upstream := range actualUpstreams {
valid := c.validateUpstream(ctx, upstream) valid := c.validateUpstream(ctx, upstream)
if valid == nil { if valid == nil {
requeue = true requeue = true
} else { } else {
validatedUpstreams = append(validatedUpstreams, *valid) validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid))
} }
} }
c.cache.SetIDPList(validatedUpstreams) c.cache.SetIDPList(validatedUpstreams)
@ -150,10 +158,14 @@ func (c *controller) Sync(ctx controllerlib.Context) error {
// validateUpstream validates the provided v1alpha1.UpstreamOIDCProvider and returns the validated configuration as a // validateUpstream validates the provided v1alpha1.UpstreamOIDCProvider and returns the validated configuration as a
// provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.UpstreamOIDCProvider. // provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.UpstreamOIDCProvider.
func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *provider.UpstreamOIDCIdentityProvider { func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *upstreamoidc.ProviderConfig {
result := provider.UpstreamOIDCIdentityProvider{ result := upstreamoidc.ProviderConfig{
Name: upstream.Name, Name: upstream.Name,
Config: &oauth2.Config{
Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes), Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes),
},
UsernameClaim: upstream.Spec.Claims.Username,
GroupsClaim: upstream.Spec.Claims.Groups,
} }
conditions := []*v1alpha1.Condition{ conditions := []*v1alpha1.Condition{
c.validateSecret(upstream, &result), c.validateSecret(upstream, &result),
@ -180,7 +192,7 @@ func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alp
} }
// validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition. // validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition.
func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
secretName := upstream.Spec.Client.SecretName secretName := upstream.Spec.Client.SecretName
// Fetch the Secret from informer cache. // Fetch the Secret from informer cache.
@ -217,7 +229,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
} }
// If everything is valid, update the result and set the condition to true. // If everything is valid, update the result and set the condition to true.
result.ClientID = string(clientID) result.Config.ClientID = string(clientID)
result.Config.ClientSecret = string(clientSecret)
return &v1alpha1.Condition{ return &v1alpha1.Condition{
Type: typeClientCredsValid, Type: typeClientCredsValid,
Status: v1alpha1.ConditionTrue, Status: v1alpha1.ConditionTrue,
@ -227,9 +240,9 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
} }
// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition. // validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition { func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
// Get the provider (from cache if possible). // Get the provider and HTTP Client from cache if possible.
discoveredProvider := c.validatorCache.getProvider(&upstream.Spec) discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec)
// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache. // If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
if discoveredProvider == nil { if discoveredProvider == nil {
@ -242,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
Message: err.Error(), Message: err.Error(),
} }
} }
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}} httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer) discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
if err != nil { if err != nil {
@ -255,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
} }
// Update the cache with the newly discovered value. // Update the cache with the newly discovered value.
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient)
} }
// Parse out and validate the discovered authorize endpoint. // Parse out and validate the discovered authorize endpoint.
@ -278,7 +291,9 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
} }
// If everything is valid, update the result and set the condition to true. // If everything is valid, update the result and set the condition to true.
result.AuthorizationURL = *authURL result.Config.Endpoint = discoveredProvider.Endpoint()
result.Provider = discoveredProvider
result.Client = httpClient
return &v1alpha1.Condition{ return &v1alpha1.Condition{
Type: typeOIDCDiscoverySucceeded, Type: typeOIDCDiscoverySucceeded,
Status: v1alpha1.ConditionTrue, Status: v1alpha1.ConditionTrue,

View File

@ -24,9 +24,11 @@ import (
pinnipedfake "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned/fake" pinnipedfake "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned/fake"
pinnipedinformers "go.pinniped.dev/generated/1.19/client/supervisor/informers/externalversions" pinnipedinformers "go.pinniped.dev/generated/1.19/client/supervisor/informers/externalversions"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/testlogger" "go.pinniped.dev/internal/testutil/testlogger"
"go.pinniped.dev/internal/upstreamoidc"
) )
func TestController(t *testing.T) { func TestController(t *testing.T) {
@ -49,6 +51,8 @@ func TestController(t *testing.T) {
testClientID = "test-oidc-client-id" testClientID = "test-oidc-client-id"
testClientSecret = "test-oidc-client-secret" testClientSecret = "test-oidc-client-secret"
testValidSecretData = map[string][]byte{"clientID": []byte(testClientID), "clientSecret": []byte(testClientSecret)} testValidSecretData = map[string][]byte{"clientID": []byte(testClientID), "clientSecret": []byte(testClientSecret)}
testGroupsClaim = "test-groups-claim"
testUsernameClaim = "test-username-claim"
) )
tests := []struct { tests := []struct {
name string name string
@ -56,7 +60,7 @@ func TestController(t *testing.T) {
inputSecrets []runtime.Object inputSecrets []runtime.Object
wantErr string wantErr string
wantLogs []string wantLogs []string
wantResultingCache []provider.UpstreamOIDCIdentityProvider wantResultingCache []provider.UpstreamOIDCIdentityProviderI
wantResultingUpstreams []v1alpha1.UpstreamOIDCProvider wantResultingUpstreams []v1alpha1.UpstreamOIDCProvider
}{ }{
{ {
@ -80,7 +84,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="secret \"test-client-secret\" not found" "name"="test-name" "namespace"="test-namespace" "reason"="SecretNotFound" "type"="ClientCredentialsValid"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="secret \"test-client-secret\" not found" "name"="test-name" "namespace"="test-namespace" "reason"="SecretNotFound" "type"="ClientCredentialsValid"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -126,7 +130,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" has wrong type \"some-other-type\" (should be \"secrets.pinniped.dev/oidc-client\")" "name"="test-name" "namespace"="test-namespace" "reason"="SecretWrongType" "type"="ClientCredentialsValid"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" has wrong type \"some-other-type\" (should be \"secrets.pinniped.dev/oidc-client\")" "name"="test-name" "namespace"="test-namespace" "reason"="SecretWrongType" "type"="ClientCredentialsValid"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -171,7 +175,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" is missing required keys [\"clientID\" \"clientSecret\"]" "name"="test-name" "namespace"="test-namespace" "reason"="SecretMissingKeys" "type"="ClientCredentialsValid"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" is missing required keys [\"clientID\" \"clientSecret\"]" "name"="test-name" "namespace"="test-namespace" "reason"="SecretMissingKeys" "type"="ClientCredentialsValid"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -219,7 +223,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -267,7 +271,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: no certificates found" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: no certificates found" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: no certificates found" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: no certificates found" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -312,7 +316,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to perform OIDC discovery against \"invalid-url\"" "reason"="Unreachable" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to perform OIDC discovery against \"invalid-url\"" "reason"="Unreachable" "status"="False" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to perform OIDC discovery against \"invalid-url\"" "name"="test-name" "namespace"="test-namespace" "reason"="Unreachable" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to perform OIDC discovery against \"invalid-url\"" "name"="test-name" "namespace"="test-namespace" "reason"="Unreachable" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -358,7 +362,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -404,7 +408,7 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`,
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{}, wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -437,6 +441,7 @@ func TestController(t *testing.T) {
TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64},
Client: v1alpha1.OIDCClient{SecretName: testSecretName}, Client: v1alpha1.OIDCClient{SecretName: testSecretName},
AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: append(testAdditionalScopes, "xyz", "openid")}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: append(testAdditionalScopes, "xyz", "openid")},
Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim},
}, },
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
Phase: "Error", Phase: "Error",
@ -455,12 +460,16 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`,
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{
&oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: testName, Name: testName,
ClientID: testClientID, ClientID: testClientID,
AuthorizationURL: *testIssuerAuthorizeURL, AuthorizationURL: *testIssuerAuthorizeURL,
Scopes: append(testExpectedScopes, "xyz"), Scopes: append(testExpectedScopes, "xyz"),
}}, UsernameClaim: testUsernameClaim,
GroupsClaim: testGroupsClaim,
},
},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -481,6 +490,7 @@ func TestController(t *testing.T) {
TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64}, TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64},
Client: v1alpha1.OIDCClient{SecretName: testSecretName}, Client: v1alpha1.OIDCClient{SecretName: testSecretName},
AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: testAdditionalScopes}, AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: testAdditionalScopes},
Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim},
}, },
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
Phase: "Ready", Phase: "Ready",
@ -499,12 +509,16 @@ func TestController(t *testing.T) {
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`,
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`, `upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
}, },
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{ wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{
&oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: testName, Name: testName,
ClientID: testClientID, ClientID: testClientID,
AuthorizationURL: *testIssuerAuthorizeURL, AuthorizationURL: *testIssuerAuthorizeURL,
Scopes: testExpectedScopes, Scopes: testExpectedScopes,
}}, UsernameClaim: testUsernameClaim,
GroupsClaim: testGroupsClaim,
},
},
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{ wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234}, ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
Status: v1alpha1.UpstreamOIDCProviderStatus{ Status: v1alpha1.UpstreamOIDCProviderStatus{
@ -527,7 +541,9 @@ func TestController(t *testing.T) {
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
testLog := testlogger.New(t) testLog := testlogger.New(t)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetIDPList([]provider.UpstreamOIDCIdentityProvider{{Name: "initial-entry"}}) cache.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{
&upstreamoidc.ProviderConfig{Name: "initial-entry"},
})
controller := New( controller := New(
cache, cache,
@ -551,7 +567,18 @@ func TestController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n")) require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n"))
require.ElementsMatch(t, tt.wantResultingCache, cache.GetIDPList())
actualIDPList := cache.GetIDPList()
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList {
actualIDP := actualIDPList[i].(*upstreamoidc.ProviderConfig)
require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName())
require.Equal(t, tt.wantResultingCache[i].GetClientID(), actualIDP.GetClientID())
require.Equal(t, tt.wantResultingCache[i].GetAuthorizationURL().String(), actualIDP.GetAuthorizationURL().String())
require.Equal(t, tt.wantResultingCache[i].GetUsernameClaim(), actualIDP.GetUsernameClaim())
require.Equal(t, tt.wantResultingCache[i].GetGroupsClaim(), actualIDP.GetGroupsClaim())
require.ElementsMatch(t, tt.wantResultingCache[i].GetScopes(), actualIDP.GetScopes())
}
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{}) actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{})
require.NoError(t, err) require.NoError(t, err)

View File

@ -30,7 +30,7 @@ const (
ErrSecretTypeMismatch = constable.Error("secret storage data has incorrect type") ErrSecretTypeMismatch = constable.Error("secret storage data has incorrect type")
ErrSecretLabelMismatch = constable.Error("secret storage data has incorrect label") ErrSecretLabelMismatch = constable.Error("secret storage data has incorrect label")
ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") // TODO do we need this? ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version")
) )
type Storage interface { type Storage interface {
@ -139,7 +139,7 @@ func (s *secretsStorage) toSecret(signature, resourceVersion string, data JSON)
Labels: map[string]string{ Labels: map[string]string{
secretLabelKey: s.resource, // make it easier to find this stuff via kubectl secretLabelKey: s.resource, // make it easier to find this stuff via kubectl
}, },
OwnerReferences: nil, // TODO we should set this to make sure stuff gets clean up OwnerReferences: nil,
}, },
Data: map[string][]byte{ Data: map[string][]byte{
secretDataKey: buf, secretDataKey: buf,

View File

@ -62,17 +62,17 @@ func TestStorage(t *testing.T) {
}{ }{
{ {
name: "get non-existent", name: "get non-existent",
resource: "authorization-codes", resource: "authcode",
mocks: nil, mocks: nil,
run: func(t *testing.T, storage Storage) error { run: func(t *testing.T, storage Storage) error {
_, err := storage.Get(ctx, "not-exists", nil) _, err := storage.Get(ctx, "not-exists", nil)
return err return err
}, },
wantActions: []coretesting.Action{ wantActions: []coretesting.Action{
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-t2fx46yyvs3a"), coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-t2fx46yyvs3a"),
}, },
wantSecrets: nil, wantSecrets: nil,
wantErr: `failed to get authorization-codes for signature not-exists: secrets "pinniped-storage-authorization-codes-t2fx46yyvs3a" not found`, wantErr: `failed to get authcode for signature not-exists: secrets "pinniped-storage-authcode-t2fx46yyvs3a" not found`,
}, },
{ {
name: "delete non-existent", name: "delete non-existent",

View File

@ -1,334 +0,0 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package authorizationcode
import (
"context"
"crypto/ed25519"
"crypto/x509"
"encoding/json"
"math/rand"
"net/url"
"strings"
"testing"
"time"
fuzz "github.com/google/gofuzz"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/openid"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/kubernetes/fake"
coretesting "k8s.io/client-go/testing"
)
func TestAuthorizeCodeStorage(t *testing.T) {
ctx := context.Background()
secretsGVR := schema.GroupVersionResource{
Group: "",
Version: "v1",
Resource: "secrets",
}
const namespace = "test-ns"
type mocker interface {
AddReactor(verb, resource string, reaction coretesting.ReactionFunc)
PrependReactor(verb, resource string, reaction coretesting.ReactionFunc)
Tracker() coretesting.ObjectTracker
}
tests := []struct {
name string
mocks func(*testing.T, mocker)
run func(*testing.T, oauth2.AuthorizeCodeStorage) error
wantActions []coretesting.Action
wantSecrets []corev1.Secret
wantErr string
}{
{
name: "create, get, invalidate standard flow",
mocks: nil,
run: func(t *testing.T, storage oauth2.AuthorizeCodeStorage) error {
request := &fosite.AuthorizeRequest{
ResponseTypes: fosite.Arguments{"not-code"},
RedirectURI: &url.URL{
Scheme: "",
Opaque: "weee",
User: &url.Userinfo{},
Host: "",
Path: "/callback",
RawPath: "",
ForceQuery: false,
RawQuery: "",
Fragment: "",
},
State: "stated",
HandledResponseTypes: fosite.Arguments{"not-type"},
Request: fosite.Request{
ID: "abcd-1",
RequestedAt: time.Time{},
Client: &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{
ID: "pinny",
Secret: nil,
RedirectURIs: nil,
GrantTypes: nil,
ResponseTypes: nil,
Scopes: nil,
Audience: nil,
Public: true,
},
JSONWebKeysURI: "where",
JSONWebKeys: nil,
TokenEndpointAuthMethod: "something",
RequestURIs: nil,
RequestObjectSigningAlgorithm: "",
TokenEndpointAuthSigningAlgorithm: "",
},
RequestedScope: nil,
GrantedScope: nil,
Form: url.Values{"key": []string{"val"}},
Session: &openid.DefaultSession{
Claims: nil,
Headers: nil,
ExpiresAt: nil,
Username: "snorlax",
Subject: "panda",
},
RequestedAudience: nil,
GrantedAudience: nil,
},
}
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
require.NoError(t, err)
newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
require.NoError(t, err)
require.Equal(t, request, newRequest)
return storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
},
wantActions: []coretesting.Action{
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authorization-codes",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"active":true,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authorization-codes",
}),
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"),
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"),
coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authorization-codes",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authorization-codes",
}),
},
wantSecrets: []corev1.Secret{
{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
Namespace: namespace,
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authorization-codes",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authorization-codes",
},
},
wantErr: "",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
client := fake.NewSimpleClientset()
if tt.mocks != nil {
tt.mocks(t, client)
}
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
err := tt.run(t, storage)
require.Equal(t, tt.wantErr, errString(err))
require.Equal(t, tt.wantActions, client.Actions())
actualSecrets, err := secrets.List(ctx, metav1.ListOptions{})
require.NoError(t, err)
require.Equal(t, tt.wantSecrets, actualSecrets.Items)
})
}
}
func errString(err error) string {
if err == nil {
return ""
}
return err.Error()
}
// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session.
// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations.
func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) {
validSession := NewValidEmptyAuthorizeCodeSession()
// sanity check our valid session
extractedRequest, err := validateAndExtractAuthorizeRequest(validSession.Request)
require.NoError(t, err)
require.Equal(t, validSession.Request, extractedRequest)
// checked above
defaultClient := validSession.Request.Request.Client.(*fosite.DefaultOpenIDConnectClient)
defaultSession := validSession.Request.Request.Session.(*openid.DefaultSession)
// makes it easier to use a raw string
replacer := strings.NewReplacer("`", "a")
randString := func(c fuzz.Continue) string {
for {
s := c.RandString()
if len(s) == 0 {
continue // skip empty string
}
return replacer.Replace(s)
}
}
// deterministic fuzzing of fosite.AuthorizeRequest
f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs(
// these functions guarantee that these are the only interface types we need to fill out
// if fosite.AuthorizeRequest changes to add more, the fuzzer will panic
func(fc *fosite.Client, c fuzz.Continue) {
c.Fuzz(defaultClient)
*fc = defaultClient
},
func(fs *fosite.Session, c fuzz.Continue) {
c.Fuzz(defaultSession)
*fs = defaultSession
},
// these types contain an interface{} that we need to handle
// this is safe because we explicitly provide the openid.DefaultSession concrete type
func(value *map[string]interface{}, c fuzz.Continue) {
// cover all the JSON data types just in case
*value = map[string]interface{}{
randString(c): float64(c.Intn(1 << 32)),
randString(c): map[string]interface{}{
randString(c): []interface{}{float64(c.Intn(1 << 32))},
randString(c): map[string]interface{}{
randString(c): nil,
randString(c): map[string]interface{}{
randString(c): c.RandBool(),
},
},
},
}
},
// JWK contains an interface{} Key that we need to handle
// this is safe because JWK explicitly implements JSON marshalling and unmarshalling
func(jwk *jose.JSONWebKey, c fuzz.Continue) {
key, _, err := ed25519.GenerateKey(c)
require.NoError(t, err)
jwk.Key = key
// set these fields to make the .Equal comparison work
jwk.Certificates = []*x509.Certificate{}
jwk.CertificatesURL = &url.URL{}
jwk.CertificateThumbprintSHA1 = []byte{}
jwk.CertificateThumbprintSHA256 = []byte{}
},
// set this to make the .Equal comparison work
// this is safe because Time explicitly implements JSON marshalling and unmarshalling
func(tp *time.Time, c fuzz.Continue) {
*tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC()
},
// make random strings that do not contain any ` characters
func(s *string, c fuzz.Continue) {
*s = randString(c)
},
// handle string type alias
func(s *fosite.TokenType, c fuzz.Continue) {
*s = fosite.TokenType(randString(c))
},
// handle string type alias
func(s *fosite.Arguments, c fuzz.Continue) {
n := c.Intn(3) + 1 // 1 to 3 items
arguments := make(fosite.Arguments, n)
for i := range arguments {
arguments[i] = randString(c)
}
*s = arguments
},
)
f.Fuzz(validSession)
const name = "fuzz" // value is irrelevant
ctx := context.Background()
secrets := fake.NewSimpleClientset().CoreV1().Secrets(name)
storage := New(secrets)
// issue a create using the fuzzed request to confirm that marshalling works
err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request)
require.NoError(t, err)
// retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works
newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil)
require.NoError(t, err)
// the fuzzed request and the copy from storage should be exactly the same
require.Equal(t, validSession.Request, newRequest)
secretList, err := secrets.List(ctx, metav1.ListOptions{})
require.NoError(t, err)
require.Len(t, secretList.Items, 1)
authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"])
// set these to match CreateAuthorizeCodeSession so that .JSONEq works
validSession.Active = true
validSession.Version = "1"
validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t")
require.NoError(t, err)
authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes)
// the fuzzed session and storage session should have identical JSON
require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage)
// while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed,
// if it adds a new field that can be fuzzed, this check will fail
// thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field)
require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing)
}

View File

@ -16,11 +16,11 @@ import (
"go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/crud" "go.pinniped.dev/internal/crud"
"go.pinniped.dev/internal/fositestorage"
) )
const ( const (
ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.AuthorizeRequest") ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must be present")
ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil")
ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version") ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version")
authorizeCodeStorageVersion = "1" authorizeCodeStorageVersion = "1"
@ -34,25 +34,24 @@ type authorizeCodeStorage struct {
type AuthorizeCodeSession struct { type AuthorizeCodeSession struct {
Active bool `json:"active"` Active bool `json:"active"`
Request *fosite.AuthorizeRequest `json:"request"` Request *fosite.Request `json:"request"`
Version string `json:"version"` Version string `json:"version"`
} }
func New(secrets corev1client.SecretInterface) oauth2.AuthorizeCodeStorage { func New(secrets corev1client.SecretInterface) oauth2.AuthorizeCodeStorage {
return &authorizeCodeStorage{storage: crud.New("authorization-codes", secrets)} return &authorizeCodeStorage{storage: crud.New("authcode", secrets)}
} }
func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error { func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error {
// this conversion assumes that we do not wrap the default type in any way // This conversion assumes that we do not wrap the default type in any way
// i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation // i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation
// note that because this type is serialized and stored in Kube, we cannot easily change the implementation later // note that because this type is serialized and stored in Kube, we cannot easily change the implementation later
// TODO hydra uses the fosite.Request struct and ignores the extra fields in fosite.AuthorizeRequest request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
request, err := validateAndExtractAuthorizeRequest(requester)
if err != nil { if err != nil {
return err return err
} }
// TODO hydra stores specific fields from the requester // Note, in case it is helpful, that Hydra stores specific fields from the requester:
// request ID // request ID
// requestedAt // requestedAt
// OAuth client ID // OAuth client ID
@ -70,12 +69,11 @@ func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, s
} }
func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) {
// TODO hydra uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes // Note, in case it is helpful, that Hydra:
// - uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes
// TODO hydra gets the client from its DB as a concrete type via client ID, // - gets the client from its DB as a concrete type via client ID, the hydra memory client just validates that the
// the hydra memory client just validates that the client ID exists // client ID exists
// - hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length
// TODO hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length
session, _, err := a.getSession(ctx, signature) session, _, err := a.getSession(ctx, signature)
@ -88,8 +86,6 @@ func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, sign
} }
func (a *authorizeCodeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error { func (a *authorizeCodeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error {
// TODO write garbage collector for these codes
session, rv, err := a.getSession(ctx, signature) session, rv, err := a.getSession(ctx, signature)
if err != nil { if err != nil {
return err return err
@ -123,7 +119,7 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string)
ErrInvalidAuthorizeRequestVersion, signature, version, authorizeCodeStorageVersion) ErrInvalidAuthorizeRequestVersion, signature, version, authorizeCodeStorageVersion)
} }
if session.Request == nil { if session.Request.ID == "" {
return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData) return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData)
} }
@ -137,31 +133,13 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string)
func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession { func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession {
return &AuthorizeCodeSession{ return &AuthorizeCodeSession{
Request: &fosite.AuthorizeRequest{ Request: &fosite.Request{
Request: fosite.Request{
Client: &fosite.DefaultOpenIDConnectClient{}, Client: &fosite.DefaultOpenIDConnectClient{},
Session: &openid.DefaultSession{}, Session: &openid.DefaultSession{},
}, },
},
} }
} }
func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.AuthorizeRequest, error) {
request, ok1 := requester.(*fosite.AuthorizeRequest)
if !ok1 {
return nil, ErrInvalidAuthorizeRequestType
}
_, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient)
_, ok3 := request.Session.(*openid.DefaultSession)
valid := ok2 && ok3
if !valid {
return nil, ErrInvalidAuthorizeRequestType
}
return request, nil
}
var _ interface { var _ interface {
Is(error) bool Is(error) bool
Unwrap() error Unwrap() error
@ -189,59 +167,37 @@ func (e *errSerializationFailureWithCause) Error() string {
const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
"active": true, "active": true,
"request": { "request": {
"responseTypes": [ "id": "嫎l蟲aƖ啘艿",
"¥Îʒ襧.ɕ7崛瀇莒AȒ[ɠ牐7#$ɭ", "requestedAt": "2082-11-10T18:36:11.627253638Z",
".5ȿELj9ûF済(D疻翋膗",
"螤Yɫüeɯ紤邥翔勋\\RBʒ;-"
],
"redirectUri": {
"Scheme": "ħesƻU赒M喦_ģ",
"Opaque": "Ġ/_章Ņ缘T蝟NJ儱礹燃ɢ",
"User": {},
"Host": "ȳ4螘Wo",
"Path": "}i{",
"RawPath": "5Dža丝eF0eė鱊hǒx蔼Q",
"ForceQuery": true,
"RawQuery": "熤1bbWV",
"Fragment": "ȋc剠鏯ɽÿ¸",
"RawFragment": "qƤ"
},
"state": "@n,x竘Şǥ嗾稀'ã击漰怼禝穞梠Ǫs",
"handledResponseTypes": [
"m\"e尚鬞ƻɼ抹d誉y鿜Ķ"
],
"id": "ō澩ć|3U2Ǜl霨ǦǵpƉ",
"requestedAt": "1989-11-05T22:02:31.105295894Z",
"client": { "client": {
"id": "[:c顎疻紵D", "id": "!ſɄĈp[述齛ʘUȻ.5ȿE",
"client_secret": "mQ==", "client_secret": "UQ==",
"redirect_uris": [ "redirect_uris": [
"恣S@T嵇LJV,Æ櫔袆鋹奘菲", "ǣ珑 ʑ飶畛Ȳ螤Yɫüeɯ紤邥翔勋\\",
"ãƻʚ肈ą8O+a駣Ʉɼk瘸'鴵y" "Bʒ;",
"鿃攴Ųęʍ鎾ʦ©cÏN,Ġ/_"
], ],
"grant_types": [ "grant_types": [
".湆ê\"唐", "憉sHĒ尥窘挼Ŀʼn"
"曎餄FxD溪躲珫ÈşɜȨû臓嬣\"ǃŤz"
], ],
"response_types": [ "response_types": [
"Ņʘʟ車sʊ儓JǐŪɺǣy|耑ʄ" "4",
"ʄÔ@}i{絧遗Ū^ȝĸ谋Vʋ鱴閇T"
], ],
"scopes": [ "scopes": [
"Ą", "R鴝順諲ŮŚ节ȭŀȋc剠鏯ɽÿ¸"
"萙Į(潶饏熞ĝƌĆ1",
"əȤ4Į筦p煖鵄$睱奐耡q"
], ],
"audience": [ "audience": [
"Ʃǣ鿫/Ò敫ƤV" "Ƥ"
], ],
"public": true, "public": true,
"jwks_uri": "ȩđ[嬧鱒Ȁ彆媚杨嶒ĤG", "jwks_uri": "BA瘪囷ɫCʄɢ雐譄uée'",
"jwks": { "jwks": {
"keys": [ "keys": [
{ {
"kty": "OKP", "kty": "OKP",
"crv": "Ed25519", "crv": "Ed25519",
"x": "JmA-6KpjzqKu0lq9OiB6ORL4s2UzBFPsE1hm6vESeXM", "x": "nK9xgX_iN7u3u_i8YOO7ZRT_WK028Vd_nhtsUu7Eo6E",
"x5u": { "x5u": {
"Scheme": "", "Scheme": "",
"Opaque": "", "Opaque": "",
@ -258,24 +214,7 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
{ {
"kty": "OKP", "kty": "OKP",
"crv": "Ed25519", "crv": "Ed25519",
"x": "LbRC1_3HEe5o7Japk9jFp3_7Ou7Gi2gpqrVrIi0eLDQ", "x": "UbbswQgzWhfGCRlwQmMp6fw_HoIoqkIaKT-2XN2fuYU",
"x5u": {
"Scheme": "",
"Opaque": "",
"User": null,
"Host": "",
"Path": "",
"RawPath": "",
"ForceQuery": false,
"RawQuery": "",
"Fragment": "",
"RawFragment": ""
}
},
{
"kty": "OKP",
"crv": "Ed25519",
"x": "Ovk4DF8Yn3mkULuTqnlGJxFnKGu9EL6Xcf2Nql9lK3c",
"x5u": { "x5u": {
"Scheme": "", "Scheme": "",
"Opaque": "", "Opaque": "",
@ -291,91 +230,95 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
} }
] ]
}, },
"token_endpoint_auth_method": "\u0026(K鵢Kj ŏ9Q韉Ķ%嶑輫ǘ(", "token_endpoint_auth_method": "ŚǗƳȕ暭Q0ņP羾,塐",
"request_uris": [ "request_uris": [
":", "lj翻LH^俤µDzɹ@©|\u003eɃ",
"6ě#嫀^xz Ū胧r" "[:c顎疻紵D"
], ],
"request_object_signing_alg": "^¡!犃ĹĐJí¿ō擫ų懫砰¿", "request_object_signing_alg": "m1Ì恣S@T嵇LJV,Æ櫔袆鋹奘",
"token_endpoint_auth_signing_alg": "ƈŮå" "token_endpoint_auth_signing_alg": "Fãƻʚ肈ą8O+a駣"
}, },
"scopes": [ "scopes": [
"阃.Ù頀ʌGa皶竇瞍涘¹", "ɼk瘸'鴵yſǮŁ±\u003eFA曎餄FxD溪",
"ȽŮ切衖庀ŰŒ矠", "綻N镪p赌h%桙dĽ"
"楓)馻řĝǕ菸Tĕ1伞柲\u003c\"ʗȆ\\雤"
], ],
"grantedScopes": [ "grantedScopes": [
"ơ鮫R嫁ɍUƞ9+u!Ȱ", "癗E]Ņʘʟ車s"
"}Ă岜"
], ],
"form": { "form": {
"旸Ť/Õ薝隧;綡,鼞纂=": [ "蹬器ķ8ŷ萒寎廭#疶昄Ą-Ƃƞ轵": [
"[滮]憀", "熞ĝƌĆ1ȇyǴ濎=Tʉȼʁŀ\u003c",
"3\u003eÙœ蓄UK嗤眇疟Țƒ1v¸KĶ" "耡q戨稞R÷mȵg釽[ƞ@",
"đ[嬧鱒Ȁ彆媚杨嶒ĤGÀ吧Lŷ"
],
"餟": [
"蒍z\u0026(K鵢Kj ŏ9Q韉Ķ%",
"輫ǘ(¨Ƞ亱6ě#嫀^xz ",
"@耢ɝ^¡!犃ĹĐJí¿ō擫"
] ]
}, },
"session": { "session": {
"Claims": { "Claims": {
"JTI": "};Ų斻遟a衪荖舃", "JTI": "懫砰¿C筽娴ƓaPu镈賆ŗɰ",
"Issuer": "芠顋敀拲h蝺$!", "Issuer": "皶竇瞍涘¹焕iǢǽɽĺŧ",
"Subject": "}j%(=ſ氆]垲莲顇", "Subject": "矠M6ɡǜg炾ʙ$%o6肿Ȫ",
"Audience": [ "Audience": [
"彑V\\廳蟕Țǡ蔯ʠ浵Ī龉磈螖畭5", "ƌÙ鯆GQơ鮫R嫁ɍUƞ9+u!Ȱ踾$"
"渇Ȯʕc"
], ],
"Nonce": "Ǖ=rlƆ褡{ǏS", "Nonce": "us旸Ť/Õ薝隧;綡,鼞",
"ExpiresAt": "1975-11-17T14:21:34.205609651Z", "ExpiresAt": "2065-11-30T13:47:03.613000626Z",
"IssuedAt": "2104-07-03T15:40:03.66710966Z", "IssuedAt": "1976-02-22T09:57:20.479850437Z",
"RequestedAt": "2031-05-18T05:14:19.449350555Z", "RequestedAt": "2016-04-13T04:18:53.648949323Z",
"AuthTime": "2018-01-27T07:55:06.056862114Z", "AuthTime": "2098-07-12T04:38:54.034043015Z",
"AccessTokenHash": "鹰肁躧", "AccessTokenHash": "滮]",
"AuthenticationContextClassReference": "", "AuthenticationContextClassReference": "°3\u003eÙ",
"AuthenticationMethodsReference": "DQh:uȣ", "AuthenticationMethodsReference": "k?µ鱔ǤÂ",
"CodeHash": "ɘȏıȒ諃龟", "CodeHash": "Țƒ1v¸KĶ跭};",
"Extra": { "Extra": {
"a": { "=ſ氆": {
"^i臏f恡ƨ彮": { "Ƿī,廖ʡ彑V\\廳蟕Ț": [
"DĘ敨ýÏʥZq7烱藌\\": null, 843216989
"V": { ],
"őŧQĝ微'X焌襱ǭɕņ殥!_n": false "蔯ʠ浵Ī": {
"H\"nǕ=rlƆ褡{ǏSȳŅ": {
"Žg": false
},
"枱鰧ɛ鸁A渇": null
} }
}, },
"Ż猁": [ "斻遟a衪荖舃9闄岈锘肺ńʥƕU}j%": 2520197933
1706822246
]
},
"Ò椪)ɫqň2搞Ŀ高摠鲒鿮禗O": 1233332227
} }
}, },
"Headers": { "Headers": {
"Extra": { "Extra": {
"?戋璖$9\u0026": { "熒ɘȏıȒ諃龟ŴŠ'耐Ƭ扵ƹ玄ɕwL": {
"µcɕ餦ÑEǰ哤癨浦浏1R": [ "ýÏʥZq7烱藌\\捀¿őŧQ": {
3761201123 "微'X焌襱ǭɕņ殥!_": null,
], "荇届UȚ?戋璖$9\u00269舋": {
"頓ć§蚲6rǦ\u003cqċ": { "ɕ餦ÑEǰ哤癨浦浏1Rk頓ć§蚲6": true
"Łʀ§ȏœɽDz斡冭ȸěaʜD捛?½ʀ+": null,
"ɒúIJ誠ƉyÖ.峷1藍殙菥趏": {
"jHȬȆ#)\u003cX": true
}
} }
}, },
"U": 1354158262 "鲒鿮禗O暒aJP鐜?ĮV嫎h譭ȉ]DĘ": [
954647573
]
},
"皩Ƭ}Ɇ.雬Ɨ´唁": 1572524915
} }
}, },
"ExpiresAt": { "ExpiresAt": {
"\"嘬ȹĹaó剺撱Ȱ": "1985-09-09T04:35:40.533197189Z", "\u003cqċ譈8ŪɎP绿MÅ": "2031-10-18T22:07:34.950803105Z",
"ʆ\u003e": "1998-08-07T05:37:11.759718906Z", "ȸěaʜD捛?½ʀ+Ċ偢镳ʬÍɷȓ\u003c": "2049-05-13T15:27:20.968432454Z"
"柏ʒ鴙*鸆偡Ȓ肯Ûx": "2036-12-19T06:36:14.414805124Z"
}, },
"Username": "qmʎaðƠ绗ʢ緦Hū", "Username": "1藍殙菥趏酱Nʎ\u0026^横懋ƶ峦Fïȫƅw",
"Subject": "屾Ê窢ɋ鄊qɠ谫ǯǵƕ牀1鞊\\ȹ)" "Subject": "檾ĩĆ爨4犹|v炩f柏ʒ鴙*鸆偡"
}, },
"requestedAudience": [ "requestedAudience": [
"鉍商OɄƣ圔,xĪɏV鵅砍" "肯Ûx穞Ƀ",
"ź蕴3ǐ薝Ƅ腲=ʐ诂鱰屾Ê窢ɋ鄊qɠ谫"
], ],
"grantedAudience": [ "grantedAudience": [
"C笜嚯\u003cǐšɚĀĥʋ6鉅\\þc涎漄Ɨ腼" "ǵƕ牀1鞊\\ȹ)}鉍商OɄƣ圔,xĪ",
"悾xn冏裻摼0Ʈ蚵Ȼ塕»£#稏扟X"
] ]
}, },
"version": "1" "version": "1"

View File

@ -0,0 +1,401 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package authorizationcode
import (
"context"
"crypto/ed25519"
"crypto/x509"
"encoding/json"
"fmt"
"math/rand"
"net/url"
"strings"
"testing"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
fuzz "github.com/google/gofuzz"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/kubernetes/fake"
kubetesting "k8s.io/client-go/testing"
"go.pinniped.dev/internal/fositestorage"
)
const namespace = "test-ns"
func TestAuthorizationCodeStorage(t *testing.T) {
ctx := context.Background()
secretsGVR := schema.GroupVersionResource{
Group: "",
Version: "v1",
Resource: "secrets",
}
wantActions := []kubetesting.Action{
kubetesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authcode",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authcode",
}),
kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"),
kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"),
kubetesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authcode",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authcode",
}),
}
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
ID: "abcd-1",
RequestedAt: time.Time{},
Client: &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{
ID: "pinny",
Secret: nil,
RedirectURIs: nil,
GrantTypes: nil,
ResponseTypes: nil,
Scopes: nil,
Audience: nil,
Public: true,
},
JSONWebKeysURI: "where",
JSONWebKeys: nil,
TokenEndpointAuthMethod: "something",
RequestURIs: nil,
RequestObjectSigningAlgorithm: "",
TokenEndpointAuthSigningAlgorithm: "",
},
RequestedScope: nil,
GrantedScope: nil,
Form: url.Values{"key": []string{"val"}},
Session: &openid.DefaultSession{
Claims: nil,
Headers: nil,
ExpiresAt: nil,
Username: "snorlax",
Subject: "panda",
},
RequestedAudience: nil,
GrantedAudience: nil,
}
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
require.NoError(t, err)
newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
require.NoError(t, err)
require.Equal(t, request, newRequest)
err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
require.NoError(t, err)
require.Equal(t, wantActions, client.Actions())
// Doing a Get on an invalidated session should still return the session, but also return an error.
invalidatedRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
require.EqualError(t, err, "authorization code session for fancy-signature has already been used: Authorization code has ben invalidated")
require.Equal(t, "abcd-1", invalidatedRequest.GetID())
}
func TestGetNotFound(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
_, notFoundErr := storage.GetAuthorizeCodeSession(ctx, "non-existent-signature", nil)
require.EqualError(t, notFoundErr, "not_found")
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
}
func TestInvalidateWhenNotFound(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
notFoundErr := storage.InvalidateAuthorizeCodeSession(ctx, "non-existent-signature")
require.EqualError(t, notFoundErr, "not_found")
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
}
func TestInvalidateWhenConflictOnUpdateHappens(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
client.PrependReactor("update", "secrets", func(_ kubetesting.Action) (bool, runtime.Object, error) {
return true, nil, apierrors.NewConflict(schema.GroupResource{
Group: "",
Resource: "secrets",
}, "some-secret-name", fmt.Errorf("there was a conflict"))
})
request := &fosite.Request{
ID: "some-request-id",
Client: &fosite.DefaultOpenIDConnectClient{},
Session: &openid.DefaultSession{},
}
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
require.NoError(t, err)
err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
require.EqualError(t, err, `The request could not be completed due to concurrent access: failed to update authcode for signature fancy-signature at resource version : Operation cannot be fulfilled on secrets "some-secret-name": there was a conflict`)
}
func TestWrongVersion(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authcode",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version", "active": true}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authcode",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
require.EqualError(t, err, "authorization request data has wrong version: authorization code session for fancy-signature has version not-the-right-version instead of 1")
}
func TestNilSessionRequest(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "authcode",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value", "version":"1", "active": true}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/authcode",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
require.EqualError(t, err, "malformed authorization code session for fancy-signature: authorization request data must be present")
}
func TestCreateWithNilRequester(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", nil)
require.EqualError(t, err, "requester must be of type fosite.Request")
}
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
Session: nil,
Client: &fosite.DefaultOpenIDConnectClient{},
}
err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request)
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
request = &fosite.Request{
Session: &openid.DefaultSession{},
Client: nil,
}
err = storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request)
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
}
// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session.
// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations.
func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) {
validSession := NewValidEmptyAuthorizeCodeSession()
// sanity check our valid session
extractedRequest, err := fositestorage.ValidateAndExtractAuthorizeRequest(validSession.Request)
require.NoError(t, err)
require.Equal(t, validSession.Request, extractedRequest)
// checked above
defaultClient := validSession.Request.Client.(*fosite.DefaultOpenIDConnectClient)
defaultSession := validSession.Request.Session.(*openid.DefaultSession)
// makes it easier to use a raw string
replacer := strings.NewReplacer("`", "a")
randString := func(c fuzz.Continue) string {
for {
s := c.RandString()
if len(s) == 0 {
continue // skip empty string
}
return replacer.Replace(s)
}
}
// deterministic fuzzing of fosite.Request
f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs(
// these functions guarantee that these are the only interface types we need to fill out
// if fosite.Request changes to add more, the fuzzer will panic
func(fc *fosite.Client, c fuzz.Continue) {
c.Fuzz(defaultClient)
*fc = defaultClient
},
func(fs *fosite.Session, c fuzz.Continue) {
c.Fuzz(defaultSession)
*fs = defaultSession
},
// these types contain an interface{} that we need to handle
// this is safe because we explicitly provide the openid.DefaultSession concrete type
func(value *map[string]interface{}, c fuzz.Continue) {
// cover all the JSON data types just in case
*value = map[string]interface{}{
randString(c): float64(c.Intn(1 << 32)),
randString(c): map[string]interface{}{
randString(c): []interface{}{float64(c.Intn(1 << 32))},
randString(c): map[string]interface{}{
randString(c): nil,
randString(c): map[string]interface{}{
randString(c): c.RandBool(),
},
},
},
}
},
// JWK contains an interface{} Key that we need to handle
// this is safe because JWK explicitly implements JSON marshalling and unmarshalling
func(jwk *jose.JSONWebKey, c fuzz.Continue) {
key, _, err := ed25519.GenerateKey(c)
require.NoError(t, err)
jwk.Key = key
// set these fields to make the .Equal comparison work
jwk.Certificates = []*x509.Certificate{}
jwk.CertificatesURL = &url.URL{}
jwk.CertificateThumbprintSHA1 = []byte{}
jwk.CertificateThumbprintSHA256 = []byte{}
},
// set this to make the .Equal comparison work
// this is safe because Time explicitly implements JSON marshalling and unmarshalling
func(tp *time.Time, c fuzz.Continue) {
*tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC()
},
// make random strings that do not contain any ` characters
func(s *string, c fuzz.Continue) {
*s = randString(c)
},
// handle string type alias
func(s *fosite.TokenType, c fuzz.Continue) {
*s = fosite.TokenType(randString(c))
},
// handle string type alias
func(s *fosite.Arguments, c fuzz.Continue) {
n := c.Intn(3) + 1 // 1 to 3 items
arguments := make(fosite.Arguments, n)
for i := range arguments {
arguments[i] = randString(c)
}
*s = arguments
},
)
f.Fuzz(validSession)
const name = "fuzz" // value is irrelevant
ctx := context.Background()
secrets := fake.NewSimpleClientset().CoreV1().Secrets(name)
storage := New(secrets)
// issue a create using the fuzzed request to confirm that marshalling works
err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request)
require.NoError(t, err)
// retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works
newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil)
require.NoError(t, err)
// the fuzzed request and the copy from storage should be exactly the same
require.Equal(t, validSession.Request, newRequest)
secretList, err := secrets.List(ctx, metav1.ListOptions{})
require.NoError(t, err)
require.Len(t, secretList.Items, 1)
authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"])
// set these to match CreateAuthorizeCodeSession so that .JSONEq works
validSession.Active = true
validSession.Version = "1"
validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t")
require.NoError(t, err)
authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes)
// the fuzzed session and storage session should have identical JSON
require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage)
// while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed,
// if it adds a new field that can be fuzzed, this check will fail
// thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field)
require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing)
}

View File

@ -0,0 +1,34 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package fositestorage
import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"go.pinniped.dev/internal/constable"
)
const (
ErrInvalidRequestType = constable.Error("requester must be of type fosite.Request")
ErrInvalidClientType = constable.Error("requester's client must be of type fosite.DefaultOpenIDConnectClient")
ErrInvalidSessionType = constable.Error("requester's session must be of type openid.DefaultSession")
)
func ValidateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) {
request, ok1 := requester.(*fosite.Request)
if !ok1 {
return nil, ErrInvalidRequestType
}
_, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient)
if !ok2 {
return nil, ErrInvalidClientType
}
_, ok3 := request.Session.(*openid.DefaultSession)
if !ok3 {
return nil, ErrInvalidSessionType
}
return request, nil
}

View File

@ -0,0 +1,124 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package openidconnect
import (
"context"
"fmt"
"strings"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"k8s.io/apimachinery/pkg/api/errors"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/crud"
"go.pinniped.dev/internal/fositestorage"
)
const (
ErrInvalidOIDCRequestVersion = constable.Error("oidc request data has wrong version")
ErrInvalidOIDCRequestData = constable.Error("oidc request data must be present")
ErrMalformedAuthorizationCode = constable.Error("malformed authorization code")
oidcStorageVersion = "1"
)
var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{}
type openIDConnectRequestStorage struct {
storage crud.Storage
}
type session struct {
Request *fosite.Request `json:"request"`
Version string `json:"version"`
}
func New(secrets corev1client.SecretInterface) openid.OpenIDConnectRequestStorage {
return &openIDConnectRequestStorage{storage: crud.New("oidc", secrets)}
}
func (a *openIDConnectRequestStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error {
signature, err := getSignature(authcode)
if err != nil {
return err
}
request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
if err != nil {
return err
}
_, err = a.storage.Create(ctx, signature, &session{Request: request, Version: oidcStorageVersion})
return err
}
func (a *openIDConnectRequestStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, _ fosite.Requester) (fosite.Requester, error) {
signature, err := getSignature(authcode)
if err != nil {
return nil, err
}
session, _, err := a.getSession(ctx, signature)
if err != nil {
return nil, err
}
return session.Request, err
}
func (a *openIDConnectRequestStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error {
signature, err := getSignature(authcode)
if err != nil {
return err
}
return a.storage.Delete(ctx, signature)
}
func (a *openIDConnectRequestStorage) getSession(ctx context.Context, signature string) (*session, string, error) {
session := newValidEmptyOIDCSession()
rv, err := a.storage.Get(ctx, signature, session)
if errors.IsNotFound(err) {
return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error())
}
if err != nil {
return nil, "", fmt.Errorf("failed to get oidc session for %s: %w", signature, err)
}
if version := session.Version; version != oidcStorageVersion {
return nil, "", fmt.Errorf("%w: oidc session for %s has version %s instead of %s",
ErrInvalidOIDCRequestVersion, signature, version, oidcStorageVersion)
}
if session.Request.ID == "" {
return nil, "", fmt.Errorf("malformed oidc session for %s: %w", signature, ErrInvalidOIDCRequestData)
}
return session, rv, nil
}
func newValidEmptyOIDCSession() *session {
return &session{
Request: &fosite.Request{
Client: &fosite.DefaultOpenIDConnectClient{},
Session: &openid.DefaultSession{},
},
}
}
func getSignature(authorizationCode string) (string, error) {
split := strings.Split(authorizationCode, ".")
if len(split) != 2 {
return "", ErrMalformedAuthorizationCode
}
return split[1], nil
}

View File

@ -0,0 +1,209 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package openidconnect
import (
"context"
"net/url"
"testing"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/kubernetes/fake"
coretesting "k8s.io/client-go/testing"
)
const namespace = "test-ns"
func TestOpenIdConnectStorage(t *testing.T) {
ctx := context.Background()
secretsGVR := schema.GroupVersionResource{
Group: "",
Version: "v1",
Resource: "secrets",
}
wantActions := []coretesting.Action{
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "oidc",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/oidc",
}),
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"),
coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"),
}
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
ID: "abcd-1",
RequestedAt: time.Time{},
Client: &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{
ID: "pinny",
Secret: nil,
RedirectURIs: nil,
GrantTypes: nil,
ResponseTypes: nil,
Scopes: nil,
Audience: nil,
Public: true,
},
JSONWebKeysURI: "where",
JSONWebKeys: nil,
TokenEndpointAuthMethod: "something",
RequestURIs: nil,
RequestObjectSigningAlgorithm: "",
TokenEndpointAuthSigningAlgorithm: "",
},
RequestedScope: nil,
GrantedScope: nil,
Form: url.Values{"key": []string{"val"}},
Session: &openid.DefaultSession{
Claims: nil,
Headers: nil,
ExpiresAt: nil,
Username: "snorlax",
Subject: "panda",
},
RequestedAudience: nil,
GrantedAudience: nil,
}
err := storage.CreateOpenIDConnectSession(ctx, "fancy-code.fancy-signature", request)
require.NoError(t, err)
newRequest, err := storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
require.NoError(t, err)
require.Equal(t, request, newRequest)
err = storage.DeleteOpenIDConnectSession(ctx, "fancy-code.fancy-signature")
require.NoError(t, err)
require.Equal(t, wantActions, client.Actions())
}
func TestGetNotFound(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
_, notFoundErr := storage.GetOpenIDConnectSession(ctx, "authcode.non-existent-signature", nil)
require.EqualError(t, notFoundErr, "not_found")
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
}
func TestWrongVersion(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "oidc",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/oidc",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
require.EqualError(t, err, "oidc request data has wrong version: oidc session for fancy-signature has version not-the-right-version instead of 1")
}
func TestNilSessionRequest(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "oidc",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/oidc",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
require.EqualError(t, err, "malformed oidc session for fancy-signature: oidc request data must be present")
}
func TestCreateWithNilRequester(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", nil)
require.EqualError(t, err, "requester must be of type fosite.Request")
}
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
Session: nil,
Client: &fosite.DefaultOpenIDConnectClient{},
}
err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request)
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
request = &fosite.Request{
Session: &openid.DefaultSession{},
Client: nil,
}
err = storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request)
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
}
func TestAuthcodeHasNoDot(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
err := storage.CreateOpenIDConnectSession(ctx, "all-one-part", nil)
require.EqualError(t, err, "malformed authorization code")
}

View File

@ -0,0 +1,98 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package pkce
import (
"context"
"fmt"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/handler/pkce"
"k8s.io/apimachinery/pkg/api/errors"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/crud"
"go.pinniped.dev/internal/fositestorage"
)
const (
ErrInvalidPKCERequestVersion = constable.Error("pkce request data has wrong version")
ErrInvalidPKCERequestData = constable.Error("pkce request data must be present")
pkceStorageVersion = "1"
)
var _ pkce.PKCERequestStorage = &pkceStorage{}
type pkceStorage struct {
storage crud.Storage
}
type session struct {
Request *fosite.Request `json:"request"`
Version string `json:"version"`
}
func New(secrets corev1client.SecretInterface) pkce.PKCERequestStorage {
return &pkceStorage{storage: crud.New("pkce", secrets)}
}
func (a *pkceStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error {
request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
if err != nil {
return err
}
_, err = a.storage.Create(ctx, signature, &session{Request: request, Version: pkceStorageVersion})
return err
}
func (a *pkceStorage) GetPKCERequestSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) {
session, _, err := a.getSession(ctx, signature)
if err != nil {
return nil, err
}
return session.Request, err
}
func (a *pkceStorage) DeletePKCERequestSession(ctx context.Context, signature string) error {
return a.storage.Delete(ctx, signature)
}
func (a *pkceStorage) getSession(ctx context.Context, signature string) (*session, string, error) {
session := newValidEmptyPKCESession()
rv, err := a.storage.Get(ctx, signature, session)
if errors.IsNotFound(err) {
return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error())
}
if err != nil {
return nil, "", fmt.Errorf("failed to get pkce session for %s: %w", signature, err)
}
if version := session.Version; version != pkceStorageVersion {
return nil, "", fmt.Errorf("%w: pkce session for %s has version %s instead of %s",
ErrInvalidPKCERequestVersion, signature, version, pkceStorageVersion)
}
if session.Request.ID == "" {
return nil, "", fmt.Errorf("malformed pkce session for %s: %w", signature, ErrInvalidPKCERequestData)
}
return session, rv, nil
}
func newValidEmptyPKCESession() *session {
return &session{
Request: &fosite.Request{
Client: &fosite.DefaultOpenIDConnectClient{},
Session: &openid.DefaultSession{},
},
}
}

View File

@ -0,0 +1,199 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package pkce
import (
"context"
"net/url"
"testing"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/kubernetes/fake"
coretesting "k8s.io/client-go/testing"
)
const namespace = "test-ns"
func TestPKCEStorage(t *testing.T) {
ctx := context.Background()
secretsGVR := schema.GroupVersionResource{
Group: "",
Version: "v1",
Resource: "secrets",
}
wantActions := []coretesting.Action{
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "pkce",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/pkce",
}),
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"),
coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"),
}
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
ID: "abcd-1",
RequestedAt: time.Time{},
Client: &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{
ID: "pinny",
Secret: nil,
RedirectURIs: nil,
GrantTypes: nil,
ResponseTypes: nil,
Scopes: nil,
Audience: nil,
Public: true,
},
JSONWebKeysURI: "where",
JSONWebKeys: nil,
TokenEndpointAuthMethod: "something",
RequestURIs: nil,
RequestObjectSigningAlgorithm: "",
TokenEndpointAuthSigningAlgorithm: "",
},
RequestedScope: nil,
GrantedScope: nil,
Form: url.Values{"key": []string{"val"}},
Session: &openid.DefaultSession{
Claims: nil,
Headers: nil,
ExpiresAt: nil,
Username: "snorlax",
Subject: "panda",
},
RequestedAudience: nil,
GrantedAudience: nil,
}
err := storage.CreatePKCERequestSession(ctx, "fancy-signature", request)
require.NoError(t, err)
newRequest, err := storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
require.NoError(t, err)
require.Equal(t, request, newRequest)
err = storage.DeletePKCERequestSession(ctx, "fancy-signature")
require.NoError(t, err)
require.Equal(t, wantActions, client.Actions())
}
func TestGetNotFound(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
_, notFoundErr := storage.GetPKCERequestSession(ctx, "non-existent-signature", nil)
require.EqualError(t, notFoundErr, "not_found")
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
}
func TestWrongVersion(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "pkce",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/pkce",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
require.EqualError(t, err, "pkce request data has wrong version: pkce session for fancy-signature has version not-the-right-version instead of 1")
}
func TestNilSessionRequest(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
secret := &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
ResourceVersion: "",
Labels: map[string]string{
"storage.pinniped.dev": "pkce",
},
},
Data: map[string][]byte{
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`),
"pinniped-storage-version": []byte("1"),
},
Type: "storage.pinniped.dev/pkce",
}
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
require.NoError(t, err)
_, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
require.EqualError(t, err, "malformed pkce session for fancy-signature: pkce request data must be present")
}
func TestCreateWithNilRequester(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", nil)
require.EqualError(t, err, "requester must be of type fosite.Request")
}
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
ctx := context.Background()
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets(namespace)
storage := New(secrets)
request := &fosite.Request{
Session: nil,
Client: &fosite.DefaultOpenIDConnectClient{},
}
err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request)
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
request = &fosite.Request{
Session: &openid.DefaultSession{},
Client: nil,
}
err = storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request)
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
}

View File

@ -0,0 +1,6 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package mockupstreamoidcidentityprovider
//go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider UpstreamOIDCIdentityProviderI

View File

@ -0,0 +1,159 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Code generated by MockGen. DO NOT EDIT.
// Source: go.pinniped.dev/internal/oidc/provider (interfaces: UpstreamOIDCIdentityProviderI)
// Package mockupstreamoidcidentityprovider is a generated GoMock package.
package mockupstreamoidcidentityprovider
import (
context "context"
gomock "github.com/golang/mock/gomock"
nonce "go.pinniped.dev/pkg/oidcclient/nonce"
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
pkce "go.pinniped.dev/pkg/oidcclient/pkce"
oauth2 "golang.org/x/oauth2"
url "net/url"
reflect "reflect"
)
// MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface
type MockUpstreamOIDCIdentityProviderI struct {
ctrl *gomock.Controller
recorder *MockUpstreamOIDCIdentityProviderIMockRecorder
}
// MockUpstreamOIDCIdentityProviderIMockRecorder is the mock recorder for MockUpstreamOIDCIdentityProviderI
type MockUpstreamOIDCIdentityProviderIMockRecorder struct {
mock *MockUpstreamOIDCIdentityProviderI
}
// NewMockUpstreamOIDCIdentityProviderI creates a new mock instance
func NewMockUpstreamOIDCIdentityProviderI(ctrl *gomock.Controller) *MockUpstreamOIDCIdentityProviderI {
mock := &MockUpstreamOIDCIdentityProviderI{ctrl: ctrl}
mock.recorder = &MockUpstreamOIDCIdentityProviderIMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityProviderIMockRecorder {
return m.recorder
}
// ExchangeAuthcodeAndValidateTokens mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(oidctypes.Token)
ret1, _ := ret[1].(map[string]interface{})
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4)
}
// GetAuthorizationURL mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetAuthorizationURL() *url.URL {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthorizationURL")
ret0, _ := ret[0].(*url.URL)
return ret0
}
// GetAuthorizationURL indicates an expected call of GetAuthorizationURL
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetAuthorizationURL() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationURL", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetAuthorizationURL))
}
// GetClientID mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetClientID() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClientID")
ret0, _ := ret[0].(string)
return ret0
}
// GetClientID indicates an expected call of GetClientID
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetClientID() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientID", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetClientID))
}
// GetGroupsClaim mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetGroupsClaim() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupsClaim")
ret0, _ := ret[0].(string)
return ret0
}
// GetGroupsClaim indicates an expected call of GetGroupsClaim
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetGroupsClaim() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetGroupsClaim))
}
// GetName mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetName() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetName")
ret0, _ := ret[0].(string)
return ret0
}
// GetName indicates an expected call of GetName
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetName() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetName))
}
// GetScopes mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetScopes() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetScopes")
ret0, _ := ret[0].([]string)
return ret0
}
// GetScopes indicates an expected call of GetScopes
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetScopes() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetScopes))
}
// GetUsernameClaim mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) GetUsernameClaim() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetUsernameClaim")
ret0, _ := ret[0].(string)
return ret0
}
// GetUsernameClaim indicates an expected call of GetUsernameClaim
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsernameClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetUsernameClaim))
}
// ValidateToken mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2)
ret0, _ := ret[0].(oidctypes.Token)
ret1, _ := ret[1].(map[string]interface{})
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// ValidateToken indicates an expected call of ValidateToken
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateToken), arg0, arg1, arg2)
}

View File

@ -9,14 +9,13 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gorilla/securecookie"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
@ -24,42 +23,15 @@ import (
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
const (
// Just in case we need to make a breaking change to the format of the upstream state param,
// we are including a format version number. This gives the opportunity for a future version of Pinniped
// to have the consumer of this format decide to reject versions that it doesn't understand.
upstreamStateParamFormatVersion = "1"
// The `name` passed to the encoder for encoding the upstream state param value. This name is short
// because it will be encoded into the upstream state param value and we're trying to keep that small.
upstreamStateParamEncodingName = "s"
// The name of the browser cookie which shall hold our CSRF value.
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
csrfCookieName = "__Host-pinniped-csrf"
// The `name` passed to the encoder for encoding and decoding the CSRF cookie contents.
csrfCookieEncodingName = "csrf"
)
type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider
}
// This is the encoding side of the securecookie.Codec interface.
type Encoder interface {
Encode(name string, value interface{}) (string, error)
}
func NewHandler( func NewHandler(
issuer string, downstreamIssuer string,
idpListGetter IDPListGetter, idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
upstreamStateEncoder Encoder, upstreamStateEncoder oidc.Encoder,
cookieCodec securecookie.Codec, cookieCodec oidc.Codec,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost && r.Method != http.MethodGet { if r.Method != http.MethodPost && r.Method != http.MethodGet {
@ -69,11 +41,7 @@ func NewHandler(
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method)
} }
csrfFromCookie, err := readCSRFCookie(r, cookieCodec) csrfFromCookie := readCSRFCookie(r, cookieCodec)
if err != nil {
plog.InfoErr("error reading CSRF cookie", err)
return err
}
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r)
if err != nil { if err != nil {
@ -116,15 +84,22 @@ func NewHandler(
} }
upstreamOAuthConfig := oauth2.Config{ upstreamOAuthConfig := oauth2.Config{
ClientID: upstreamIDP.ClientID, ClientID: upstreamIDP.GetClientID(),
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: upstreamIDP.AuthorizationURL.String(), AuthURL: upstreamIDP.GetAuthorizationURL().String(),
}, },
RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.Name), RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer),
Scopes: upstreamIDP.Scopes, Scopes: upstreamIDP.GetScopes(),
} }
encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) encodedStateParamValue, err := upstreamStateParam(
authorizeRequester,
upstreamIDP.GetName(),
nonceValue,
csrfValue,
pkceValue,
upstreamStateEncoder,
)
if err != nil { if err != nil {
plog.Error("authorize upstream state param error", err) plog.Error("authorize upstream state param error", err)
return err return err
@ -154,20 +129,23 @@ func NewHandler(
}) })
} }
func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { func readCSRFCookie(r *http.Request, codec oidc.Codec) csrftoken.CSRFToken {
receivedCSRFCookie, err := r.Cookie(csrfCookieName) receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
if err != nil { if err != nil {
// Error means that the cookie was not found // Error means that the cookie was not found
return "", nil return ""
} }
var csrfFromCookie csrftoken.CSRFToken var csrfFromCookie csrftoken.CSRFToken
err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil { if err != nil {
return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) // We can ignore any errors and just make a new cookie. Hopefully this will
// make the user experience better if, for example, the server rotated
// cookie signing keys and then a user submitted a very old cookie.
return ""
} }
return csrfFromCookie, nil return csrfFromCookie
} }
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
@ -178,7 +156,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
} }
} }
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (provider.UpstreamOIDCIdentityProviderI, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
return nil, httperr.New( return nil, httperr.New(
@ -191,7 +169,7 @@ func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdent
"Too many upstream providers are configured (support for multiple upstreams is not yet implemented)", "Too many upstream providers are configured (support for multiple upstreams is not yet implemented)",
) )
} }
return &allUpstreamIDPs[0], nil return allUpstreamIDPs[0], nil
} }
func generateValues( func generateValues(
@ -214,48 +192,42 @@ func generateValues(
return csrfValue, nonceValue, pkceValue, nil return csrfValue, nonceValue, pkceValue, nil
} }
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
type upstreamStateParamData struct {
AuthParams string `json:"p"`
Nonce nonce.Nonce `json:"n"`
CSRFToken csrftoken.CSRFToken `json:"c"`
PKCECode pkce.Code `json:"k"`
StateParamFormatVersion string `json:"v"`
}
func upstreamStateParam( func upstreamStateParam(
authorizeRequester fosite.AuthorizeRequester, authorizeRequester fosite.AuthorizeRequester,
upstreamName string,
nonceValue nonce.Nonce, nonceValue nonce.Nonce,
csrfValue csrftoken.CSRFToken, csrfValue csrftoken.CSRFToken,
pkceValue pkce.Code, pkceValue pkce.Code,
encoder Encoder, encoder oidc.Encoder,
) (string, error) { ) (string, error) {
stateParamData := upstreamStateParamData{ stateParamData := oidc.UpstreamStateParamData{
AuthParams: authorizeRequester.GetRequestForm().Encode(), AuthParams: authorizeRequester.GetRequestForm().Encode(),
UpstreamName: upstreamName,
Nonce: nonceValue, Nonce: nonceValue,
CSRFToken: csrfValue, CSRFToken: csrfValue,
PKCECode: pkceValue, PKCECode: pkceValue,
StateParamFormatVersion: upstreamStateParamFormatVersion, FormatVersion: oidc.UpstreamStateParamFormatVersion,
} }
encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData) encodedStateParamValue, err := encoder.Encode(oidc.UpstreamStateParamEncodingName, stateParamData)
if err != nil { if err != nil {
return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err) return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err)
} }
return encodedStateParamValue, nil return encodedStateParamValue, nil
} }
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec oidc.Codec) error {
encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue) encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err)
} }
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: csrfCookieName, Name: oidc.CSRFCookieName,
Value: encodedCSRFValue, Value: encodedCSRFValue,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
Secure: true, Secure: true,
Path: "/",
}) })
return nil return nil

View File

@ -21,6 +21,7 @@ import (
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
@ -28,6 +29,7 @@ import (
func TestAuthorizationEndpoint(t *testing.T) { func TestAuthorizationEndpoint(t *testing.T) {
const ( const (
downstreamIssuer = "https://my-downstream-issuer.com/some-path"
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback"
) )
@ -113,21 +115,19 @@ func TestAuthorizationEndpoint(t *testing.T) {
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
require.NoError(t, err) require.NoError(t, err)
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "some-idp", Name: "some-idp",
ClientID: "some-client-id", ClientID: "some-client-id",
AuthorizationURL: *upstreamAuthURL, AuthorizationURL: *upstreamAuthURL,
Scopes: []string{"scope1", "scope2"}, Scopes: []string{"scope1", "scope2"},
} }
issuer := "https://my-issuer.com/some-path" // Configure fosite the same way that the production code would, using NullStorage to turn off storage.
// Configure fosite the same way that the production code would, except use in-memory storage.
oauthStore := oidc.NullStorage{} oauthStore := oidc.NullStorage{}
hmacSecret := []byte("some secret - must have at least 32 bytes") hmacSecret := []byte("some secret - must have at least 32 bytes")
var signingKeyIsUnused *ecdsa.PrivateKey var signingKeyIsUnused *ecdsa.PrivateKey
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(issuer, oauthStore, hmacSecret, signingKeyIsUnused) oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, signingKeyIsUnused)
happyCSRF := "test-csrf" happyCSRF := "test-csrf"
happyPKCE := "test-pkce" happyPKCE := "test-pkce"
@ -206,14 +206,19 @@ func TestAuthorizationEndpoint(t *testing.T) {
return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides))
} }
expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string { expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride, upstreamNameOverride string) string {
csrf := happyCSRF csrf := happyCSRF
if csrfValueOverride != "" { if csrfValueOverride != "" {
csrf = csrfValueOverride csrf = csrfValueOverride
} }
upstreamName := upstreamOIDCIdentityProvider.Name
if upstreamNameOverride != "" {
upstreamName = upstreamNameOverride
}
encoded, err := happyStateEncoder.Encode("s", encoded, err := happyStateEncoder.Encode("s",
expectedUpstreamStateParamFormat{ oidctestutil.ExpectedUpstreamStateParamFormat{
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
U: upstreamName,
N: happyNonce, N: happyNonce,
C: csrf, C: csrf,
K: happyPKCE, K: happyPKCE,
@ -234,7 +239,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-idp", "redirect_uri": downstreamIssuer + "/callback",
}) })
} }
@ -250,8 +255,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF func() (csrftoken.CSRFToken, error) generateCSRF func() (csrftoken.CSRFToken, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
stateEncoder securecookie.Codec stateEncoder oidc.Codec
cookieEncoder securecookie.Codec cookieEncoder oidc.Codec
method string method string
path string path string
contentType string contentType string
@ -271,8 +276,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
tests := []testCase{ tests := []testCase{
{ {
name: "happy path using GET without a CSRF cookie", name: "happy path using GET without a CSRF cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -283,14 +288,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantCSRFValueInCookieHeader: happyCSRF, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
{ {
name: "happy path using GET with a CSRF cookie", name: "happy path using GET with a CSRF cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -298,17 +303,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
cookieEncoder: happyCookieEncoder, cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + " ",
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -322,13 +327,33 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantContentType: "", wantContentType: "",
wantBodyString: "", wantBodyString: "",
wantCSRFValueInCookieHeader: happyCSRF, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
}, },
{
name: "error while decoding CSRF cookie just generates a new cookie and succeeds as usual",
issuer: downstreamIssuer,
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8",
// Generated a new CSRF cookie and set it in the response.
wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true,
},
{ {
name: "happy path when downstream redirect uri matches what is configured for client except for the port number", name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -343,14 +368,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantCSRFValueInCookieHeader: happyCSRF, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
}, "")), }, "", "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
{ {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -366,8 +391,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "downstream client does not exist", name: "downstream client does not exist",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -381,8 +406,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "response type is unsupported", name: "response type is unsupported",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -397,8 +422,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "downstream scopes do not match what is configured for client", name: "downstream scopes do not match what is configured for client",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -413,8 +438,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing response type in request", name: "missing response type in request",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -429,8 +454,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing client id in request", name: "missing client id in request",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -444,8 +469,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -460,8 +485,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -476,8 +501,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -492,8 +517,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -510,8 +535,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
// This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running
// through that part of the fosite library. // through that part of the fosite library.
name: "prompt param is not allowed to have none and another legal value at the same time", name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -526,8 +551,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "OIDC validations are skipped when the openid scope was not requested", name: "OIDC validations are skipped when the openid scope was not requested",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -540,15 +565,15 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantCSRFValueInCookieHeader: happyCSRF, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(
map[string]string{"prompt": "none login", "scope": "email"}, "", map[string]string{"prompt": "none login", "scope": "email"}, "", "",
)), )),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
{ {
name: "state does not have enough entropy", name: "state does not have enough entropy",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -563,8 +588,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while encoding upstream state param", name: "error while encoding upstream state param",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -578,8 +603,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while encoding CSRF cookie value for new cookie", name: "error while encoding CSRF cookie value for new cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -593,8 +618,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -608,8 +633,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating nonce", name: "error while generating nonce",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
@ -623,8 +648,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating PKCE", name: "error while generating PKCE",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -636,26 +661,10 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantContentType: "text/plain; charset=utf-8", wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error generating PKCE param\n", wantBodyString: "Internal Server Error: error generating PKCE param\n",
}, },
{
name: "error while decoding CSRF cookie",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusUnprocessableEntity,
wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n",
},
{ {
name: "no upstream providers are configured", name: "no upstream providers are configured",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(), // empty idpListGetter: oidctestutil.NewIDPListGetter(), // empty
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
@ -664,8 +673,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "too many upstream providers are configured", name: "too many upstream providers are configured",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
@ -674,8 +683,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "PUT is a bad method", name: "PUT is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPut, method: http.MethodPut,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -684,8 +693,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "PATCH is a bad method", name: "PATCH is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPatch, method: http.MethodPatch,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -694,8 +703,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "DELETE is a bad method", name: "DELETE is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodDelete, method: http.MethodDelete,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -712,6 +721,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
} }
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String())
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
@ -742,7 +753,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
if test.wantCSRFValueInCookieHeader != "" { if test.wantCSRFValueInCookieHeader != "" {
require.Len(t, rsp.Header().Values("Set-Cookie"), 1) require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
actualCookie := rsp.Header().Get("Set-Cookie") actualCookie := rsp.Header().Get("Set-Cookie")
regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict") regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); Path=/; HttpOnly; Secure; SameSite=Strict")
submatches := regex.FindStringSubmatch(actualCookie) submatches := regex.FindStringSubmatch(actualCookie)
require.Len(t, submatches, 2) require.Len(t, submatches, 2)
captured := submatches[1] captured := submatches[1]
@ -772,13 +783,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
// Call the setter to change the upstream IDP settings. // Call the setter to change the upstream IDP settings.
newProviderSettings := provider.UpstreamOIDCIdentityProvider{ newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "some-other-idp", Name: "some-other-idp",
ClientID: "some-other-client-id", ClientID: "some-other-client-id",
AuthorizationURL: *upstreamAuthURL, AuthorizationURL: *upstreamAuthURL,
Scopes: []string{"other-scope1", "other-scope2"}, Scopes: []string{"other-scope1", "other-scope2"},
} }
test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{newProviderSettings}) test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)})
// Update the expectations of the test case to match the new upstream IDP settings. // Update the expectations of the test case to match the new upstream IDP settings.
test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(),
@ -787,11 +798,11 @@ func TestAuthorizationEndpoint(t *testing.T) {
"access_type": "offline", "access_type": "offline",
"scope": "other-scope1 other-scope2", "scope": "other-scope1 other-scope2",
"client_id": "some-other-client-id", "client_id": "some-other-client-id",
"state": expectedUpstreamStateParam(nil, ""), "state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name),
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-other-idp", "redirect_uri": downstreamIssuer + "/callback",
}, },
) )
test.wantBodyString = fmt.Sprintf(`<a href="%s">Found</a>.%s`, test.wantBodyString = fmt.Sprintf(`<a href="%s">Found</a>.%s`,
@ -807,20 +818,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
}) })
} }
// Declare a separate type from the production code to ensure that the state param's contents was serialized
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
// assertions about the redirect URL in this test.
type expectedUpstreamStateParamFormat struct {
P string `json:"p"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}
type errorReturningEncoder struct { type errorReturningEncoder struct {
securecookie.Codec oidc.Codec
} }
func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) {
@ -843,7 +842,7 @@ func requireEqualContentType(t *testing.T, actual string, expected string) {
require.Equal(t, actualContentTypeParams, expectedContentTypeParams) require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
} }
func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) { func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) {
t.Helper() t.Helper()
actualLocationURL, err := url.Parse(actualURL) actualLocationURL, err := url.Parse(actualURL)
require.NoError(t, err) require.NoError(t, err)
@ -852,13 +851,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
expectedQueryStateParam := expectedLocationURL.Query().Get("state") expectedQueryStateParam := expectedLocationURL.Query().Get("state")
require.NotEmpty(t, expectedQueryStateParam) require.NotEmpty(t, expectedQueryStateParam)
var expectedDecodedStateParam expectedUpstreamStateParamFormat var expectedDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
require.NoError(t, err) require.NoError(t, err)
actualQueryStateParam := actualLocationURL.Query().Get("state") actualQueryStateParam := actualLocationURL.Query().Get("state")
require.NotEmpty(t, actualQueryStateParam) require.NotEmpty(t, actualQueryStateParam)
var actualDecodedStateParam expectedUpstreamStateParamFormat var actualDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
require.NoError(t, err) require.NoError(t, err)
@ -871,10 +870,20 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
require.NoError(t, err) require.NoError(t, err)
expectedLocationURL, err := url.Parse(expectedURL) expectedLocationURL, err := url.Parse(expectedURL)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme) require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme,
require.Equal(t, expectedLocationURL.User, actualLocationURL.User) "schemes were not equal: expected %s but got %s", expectedURL, actualURL,
require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host) )
require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path) require.Equal(t, expectedLocationURL.User, actualLocationURL.User,
"users were not equal: expected %s but got %s", expectedURL, actualURL,
)
require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host,
"hosts were not equal: expected %s but got %s", expectedURL, actualURL,
)
require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path,
"paths were not equal: expected %s but got %s", expectedURL, actualURL,
)
expectedLocationQuery := expectedLocationURL.Query() expectedLocationQuery := expectedLocationURL.Query()
actualLocationQuery := actualLocationURL.Query() actualLocationQuery := actualLocationURL.Query()
@ -886,9 +895,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
} }
require.Equal(t, expectedLocationQuery, actualLocationQuery) require.Equal(t, expectedLocationQuery, actualLocationQuery)
} }
func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider()
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
return idpProvider
}

View File

@ -0,0 +1,306 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package callback provides a handler for the OIDC callback endpoint.
package callback
import (
"crypto/subtle"
"fmt"
"net/http"
"net/url"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog"
)
const (
// The name of the issuer claim specified in the OIDC spec.
idTokenIssuerClaim = "iss"
// The name of the subject claim specified in the OIDC spec.
idTokenSubjectClaim = "sub"
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
// ID token if the upstream OIDC IDP did not tell us to use another claim.
defaultUpstreamUsernameClaim = idTokenSubjectClaim
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
downstreamGroupsClaim = "groups"
)
func NewHandler(
idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider,
stateDecoder, cookieDecoder oidc.Decoder,
redirectURI string,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil {
return err
}
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
if upstreamIDPConfig == nil {
plog.Warning("upstream provider not found")
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
}
downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
if err != nil {
plog.Error("error reading state downstream auth params", err)
return httperr.New(http.StatusBadRequest, "error reading state downstream auth params")
}
// Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest().
reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams}
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest)
if err != nil {
plog.Error("error using state downstream auth params", err)
return httperr.New(http.StatusBadRequest, "error using state downstream auth params")
}
// Grant the openid scope only if it was requested.
grantOpenIDScopeIfRequested(authorizeRequester)
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
r.Context(),
authcode(r),
state.PKCECode,
state.Nonce,
redirectURI,
)
if err != nil {
plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName())
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
}
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
if err != nil {
return err
}
groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
if err != nil {
return err
}
openIDSession := makeDownstreamSession(username, groups)
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
if err != nil {
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err)
}
oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)
return nil
})
}
func authcode(r *http.Request) string {
return r.FormValue("code")
}
func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
if r.Method != http.MethodGet {
return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
}
csrfValue, err := readCSRFCookie(r, cookieDecoder)
if err != nil {
plog.InfoErr("error reading CSRF cookie", err)
return nil, err
}
if authcode(r) == "" {
plog.Info("code param not found")
return nil, httperr.New(http.StatusBadRequest, "code param not found")
}
if r.FormValue("state") == "" {
plog.Info("state param not found")
return nil, httperr.New(http.StatusBadRequest, "state param not found")
}
state, err := readState(r, stateDecoder)
if err != nil {
plog.InfoErr("error reading state", err)
return nil, err
}
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
plog.InfoErr("CSRF value does not match", err)
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
}
return state, nil
}
func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
for _, p := range idpListGetter.GetIDPList() {
if p.GetName() == upstreamName {
return p
}
}
return nil
}
func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
if err != nil {
// Error means that the cookie was not found
return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err)
}
var csrfFromCookie csrftoken.CSRFToken
err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil {
return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err)
}
return csrfFromCookie, nil
}
func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
var state oidc.UpstreamStateParamData
if err := stateDecoder.Decode(
oidc.UpstreamStateParamEncodingName,
r.FormValue("state"),
&state,
); err != nil {
return nil, httperr.New(http.StatusBadRequest, "error reading state")
}
if state.FormatVersion != oidc.UpstreamStateParamFormatVersion {
return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid")
}
return &state, nil
}
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
for _, scope := range authorizeRequester.GetRequestedScopes() {
if scope == "openid" {
authorizeRequester.GrantScope(scope)
}
}
}
func getUsernameFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{},
) (string, error) {
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
user := ""
if usernameClaim == "" {
// The spec says the "sub" claim is only unique per issuer, so by default when there is
// no specific username claim configured we will prepend the issuer string to make it globally unique.
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
if upstreamIssuer == "" {
plog.Warning(
"issuer claim in upstream ID token missing",
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
}
upstreamIssuerAsString, ok := upstreamIssuer.(string)
if !ok {
plog.Warning(
"issuer claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
}
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
usernameClaim = defaultUpstreamUsernameClaim
}
usernameAsInterface, ok := idTokenClaims[usernameClaim]
if !ok {
plog.Warning(
"no username claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim,
)
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
}
username, ok := usernameAsInterface.(string)
if !ok {
plog.Warning(
"username claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim,
)
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
}
return fmt.Sprintf("%s%s", user, username), nil
}
func getGroupsFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{},
) ([]string, error) {
groupsClaim := upstreamIDPConfig.GetGroupsClaim()
if groupsClaim == "" {
return nil, nil
}
groupsAsInterface, ok := idTokenClaims[groupsClaim]
if !ok {
plog.Warning(
"no groups claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
"configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(),
"groupsClaim", groupsClaim,
)
return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token")
}
groups, ok := groupsAsInterface.([]string)
if !ok {
plog.Warning(
"groups claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
"configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(),
"groupsClaim", groupsClaim,
)
return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format")
}
return groups, nil
}
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession {
now := time.Now().UTC()
openIDSession := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: username,
RequestedAt: now,
AuthTime: now,
},
}
if groups != nil {
openIDSession.Claims.Extra = map[string]interface{}{
downstreamGroupsClaim: groups,
}
}
return openIDSession
}

View File

@ -0,0 +1,859 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package callback
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"time"
"github.com/gorilla/securecookie"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/kubernetes/fake"
kubetesting "k8s.io/client-go/testing"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
)
const (
happyUpstreamIDPName = "upstream-idp-name"
upstreamIssuer = "https://my-upstream-issuer.com"
upstreamSubject = "abc123-some-guid"
upstreamUsername = "test-pinniped-username"
upstreamUsernameClaim = "the-user-claim"
upstreamGroupsClaim = "the-groups-claim"
happyUpstreamAuthcode = "upstream-auth-code"
happyUpstreamRedirectURI = "https://example.com/callback"
happyDownstreamState = "some-downstream-state-with-at-least-32-bytes"
happyDownstreamCSRF = "test-csrf"
happyDownstreamPKCE = "test-pkce"
happyDownstreamNonce = "test-nonce"
happyDownstreamStateVersion = "1"
downstreamIssuer = "https://my-downstream-issuer.com/path"
downstreamRedirectURI = "http://127.0.0.1/callback"
downstreamClientID = "pinniped-cli"
downstreamNonce = "some-nonce-value"
downstreamPKCEChallenge = "some-challenge"
downstreamPKCEChallengeMethod = "S256"
timeComparisonFudgeFactor = time.Second * 15
)
var (
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
happyDownstreamScopesRequested = []string{"openid", "profile", "email"}
happyDownstreamRequestParamsQuery = url.Values{
"response_type": []string{"code"},
"scope": []string{strings.Join(happyDownstreamScopesRequested, " ")},
"client_id": []string{downstreamClientID},
"state": []string{happyDownstreamState},
"nonce": []string{downstreamNonce},
"code_challenge": []string{downstreamPKCEChallenge},
"code_challenge_method": []string{downstreamPKCEChallengeMethod},
"redirect_uri": []string{downstreamRedirectURI},
}
happyDownstreamRequestParams = happyDownstreamRequestParamsQuery.Encode()
)
func TestCallbackEndpoint(t *testing.T) {
otherUpstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "other-upstream-idp-name",
ClientID: "other-some-client-id",
Scopes: []string{"other-scope1", "other-scope2"},
}
var stateEncoderHashKey = []byte("fake-hash-secret")
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
var cookieEncoderHashKey = []byte("fake-hash-secret2")
var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
happyStateCodec.SetSerializer(securecookie.JSONEncoder{})
var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieCodec.SetSerializer(securecookie.JSONEncoder{})
happyState := happyUpstreamStateParam().Build(t, happyStateCodec)
encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF)
require.NoError(t, err)
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{
Authcode: happyUpstreamAuthcode,
PKCECodeVerifier: pkce.Code(happyDownstreamPKCE),
ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce),
RedirectURI: happyUpstreamRedirectURI,
}
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState
tests := []struct {
name string
idp oidctestutil.TestUpstreamOIDCIdentityProvider
method string
path string
csrfCookie string
wantStatus int
wantBody string
wantRedirectLocationRegexp string
wantGrantedOpenidScope bool
wantDownstreamIDTokenSubject string
wantDownstreamIDTokenGroups []string
wantDownstreamRequestedScopes []string
wantDownstreamNonce string
wantDownstreamPKCEChallenge string
wantDownstreamPKCEChallengeMethod string
wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs
}{
{
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups",
idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenGroups: nil,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for",
idp: happyUpstream().WithUsernameClaim("sub").Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamSubject,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
// Pre-upstream-exchange verification
{
name: "PUT method is invalid",
method: http.MethodPut,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: PUT (try GET)\n",
},
{
name: "POST method is invalid",
method: http.MethodPost,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: POST (try GET)\n",
},
{
name: "PATCH method is invalid",
method: http.MethodPatch,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: PATCH (try GET)\n",
},
{
name: "DELETE method is invalid",
method: http.MethodDelete,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: DELETE (try GET)\n",
},
{
name: "code param was not included on request",
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithoutCode().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: code param not found\n",
},
{
name: "state param was not included on request",
method: http.MethodGet,
path: newRequestPath().WithoutState().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: state param not found\n",
},
{
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error reading state\n",
},
{
// This shouldn't happen in practice because the authorize endpoint should have already run the same
// validations, but we would like to test the error handling in this endpoint anyway.
name: "state param contains authorization request params which fail validation",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
wantStatus: http.StatusInternalServerError,
wantBody: "Internal Server Error: error while generating and saving authcode\n",
},
{
name: "state's internal version does not match what we want",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: state format version is invalid\n",
},
{
name: "state's downstream auth params element is invalid",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyUpstreamStateParam().
WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z").
Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error reading state downstream auth params\n",
},
{
name: "state's downstream auth params are missing required value (e.g., client_id)",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": ""}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error using state downstream auth params\n",
},
{
name: "state's downstream auth params does not contain openid scope",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().
WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamRequestedScopes: []string{"profile", "email"},
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "the UpstreamOIDCProvider CRD has been deleted",
idp: otherUpstreamOIDCIdentityProvider,
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: upstream provider not found\n",
},
{
name: "the CSRF cookie does not exist on request",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: CSRF cookie is missing\n",
},
{
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: error reading CSRF cookie\n",
},
{
name: "cookie csrf value does not match state csrf value",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: CSRF value does not match\n",
},
// Upstream exchange
{
name: "upstream auth code exchange fails",
idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadGateway,
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token does not contain requested username claim",
idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: no username claim in upstream ID token\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token does not contain requested groups claim",
idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token contains username claim with weird format",
idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token does not contain iss claim when using default username claim config",
idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token has an non-string iss claim when using default username claim config",
idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token contains groups claim with weird format",
idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
client := fake.NewSimpleClientset()
secrets := client.CoreV1().Secrets("some-namespace")
// Configure fosite the same way that the production code would.
// Inject this into our test subject at the last second so we get a fresh storage for every test.
oauthStore := oidc.NewKubeStorage(secrets)
hmacSecret := []byte("some secret - must have at least 32 bytes")
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret)
idpListGetter := oidctestutil.NewIDPListGetter(&test.idp)
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI)
req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie)
}
rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req)
t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String())
if test.wantExchangeAndValidateTokensCall != nil {
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
} else {
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
}
require.Equal(t, test.wantStatus, rsp.Code)
if test.wantBody != "" {
require.Equal(t, test.wantBody, rsp.Body.String())
} else {
require.Empty(t, rsp.Body.String())
}
if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test
// Assert that Location header matches regular expression.
require.Len(t, rsp.Header().Values("Location"), 1)
actualLocation := rsp.Header().Get("Location")
regex := regexp.MustCompile(test.wantRedirectLocationRegexp)
submatches := regex.FindStringSubmatch(actualLocation)
require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation)
capturedAuthCode := submatches[1]
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface
authcodeDataAndSignature := strings.Split(capturedAuthCode, ".")
require.Len(t, authcodeDataAndSignature, 2)
// Several Secrets should have been created
expectedNumberOfCreatedSecrets := 2
if test.wantGrantedOpenidScope {
expectedNumberOfCreatedSecrets++
}
require.Len(t, client.Actions(), expectedNumberOfCreatedSecrets)
actualSecretNames := []string{}
for i := range client.Actions() {
actualAction := client.Actions()[i].(kubetesting.CreateActionImpl)
require.Equal(t, "create", actualAction.GetVerb())
require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource())
actualSecret := actualAction.GetObject().(*corev1.Secret)
require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace
actualSecretNames = append(actualSecretNames, actualSecret.Name)
}
// One authcode should have been stored.
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-authcode-")
storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage(
t,
oauthStore,
authcodeDataAndSignature[1], // Authcode store key is authcode signature
test.wantGrantedOpenidScope,
test.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes,
)
// One PKCE should have been stored.
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-pkce-")
validatePKCEStorage(
t,
oauthStore,
authcodeDataAndSignature[1], // PKCE store key is authcode signature
storedRequestFromAuthcode,
storedSessionFromAuthcode,
test.wantDownstreamPKCEChallenge,
test.wantDownstreamPKCEChallengeMethod,
)
// One IDSession should have been stored, if the downstream actually requested the "openid" scope
if test.wantGrantedOpenidScope {
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-oidc")
validateIDSessionStorage(
t,
oauthStore,
capturedAuthCode, // IDSession store key is full authcode
storedRequestFromAuthcode,
storedSessionFromAuthcode,
test.wantDownstreamNonce,
)
}
}
})
}
}
type requestPath struct {
code, state *string
}
func newRequestPath() *requestPath {
c := happyUpstreamAuthcode
s := "4321"
return &requestPath{
code: &c,
state: &s,
}
}
func (r *requestPath) WithCode(code string) *requestPath {
r.code = &code
return r
}
func (r *requestPath) WithoutCode() *requestPath {
r.code = nil
return r
}
func (r *requestPath) WithState(state string) *requestPath {
r.state = &state
return r
}
func (r *requestPath) WithoutState() *requestPath {
r.state = nil
return r
}
func (r *requestPath) String() string {
path := "/downstream-provider-name/callback?"
params := url.Values{}
if r.code != nil {
params.Add("code", *r.code)
}
if r.state != nil {
params.Add("state", *r.state)
}
return path + params.Encode()
}
type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
func happyUpstreamStateParam() *upstreamStateParamBuilder {
return &upstreamStateParamBuilder{
U: happyUpstreamIDPName,
P: happyDownstreamRequestParams,
N: happyDownstreamNonce,
C: happyDownstreamCSRF,
K: happyDownstreamPKCE,
V: happyDownstreamStateVersion,
}
}
func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string {
state, err := stateEncoder.Encode("s", b)
require.NoError(t, err)
return state
}
func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder {
b.P = params
return b
}
func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder {
b.N = nonce
return b
}
func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder {
b.C = csrf
return b
}
func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder {
b.K = pkce
return b
}
func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder {
b.V = version
return b
}
type upstreamOIDCIdentityProviderBuilder struct {
idToken map[string]interface{}
usernameClaim, groupsClaim string
authcodeExchangeErr error
}
func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
return &upstreamOIDCIdentityProviderBuilder{
usernameClaim: upstreamUsernameClaim,
groupsClaim: upstreamGroupsClaim,
idToken: map[string]interface{}{
"iss": upstreamIssuer,
"sub": upstreamSubject,
upstreamUsernameClaim: upstreamUsername,
upstreamGroupsClaim: upstreamGroupMembership,
"other-claim": "should be ignored",
},
}
}
func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = claim
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = ""
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder {
u.groupsClaim = ""
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *upstreamOIDCIdentityProviderBuilder {
u.idToken[name] = value
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
delete(u.idToken, claim)
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder {
u.authcodeExchangeErr = err
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider {
return oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
UsernameClaim: u.usernameClaim,
GroupsClaim: u.groupsClaim,
Scopes: []string{"scope1", "scope2"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr
},
}
}
func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values {
copied := url.Values{}
for key, value := range query {
copied[key] = value
}
for key, value := range modifications {
if value == "" {
copied.Del(key)
} else {
copied[key] = []string{value}
}
}
return copied
}
func validateAuthcodeStorage(
t *testing.T,
oauthStore *oidc.KubeStorage,
storeKey string,
wantGrantedOpenidScope bool,
wantDownstreamIDTokenSubject string,
wantDownstreamIDTokenGroups []string,
wantDownstreamRequestedScopes []string,
) (*fosite.Request, *openid.DefaultSession) {
t.Helper()
// Get the authcode session back from storage so we can require that it was stored correctly.
storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil)
require.NoError(t, err)
// Check that storage returned the expected concrete data types.
storedRequestFromAuthcode, storedSessionFromAuthcode := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromAuthcode)
// Check which scopes were granted.
if wantGrantedOpenidScope {
require.Contains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid")
} else {
require.NotContains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid")
}
// Check all the other fields of the stored request.
require.NotEmpty(t, storedRequestFromAuthcode.ID)
require.Equal(t, downstreamClientID, storedRequestFromAuthcode.Client.GetID())
require.ElementsMatch(t, wantDownstreamRequestedScopes, storedRequestFromAuthcode.RequestedScope)
require.Nil(t, storedRequestFromAuthcode.RequestedAudience)
require.Empty(t, storedRequestFromAuthcode.GrantedAudience)
require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequestFromAuthcode.Form)
testutil.RequireTimeInDelta(t, time.Now(), storedRequestFromAuthcode.RequestedAt, timeComparisonFudgeFactor)
// We're not using these fields yet, so confirm that we did not set them (for now).
require.Empty(t, storedSessionFromAuthcode.Subject)
require.Empty(t, storedSessionFromAuthcode.Username)
require.Empty(t, storedSessionFromAuthcode.Headers)
// The authcode that we are issuing should be good for the length of time that we declare in the fosite config.
testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*3), storedSessionFromAuthcode.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor)
require.Len(t, storedSessionFromAuthcode.ExpiresAt, 1)
// Now confirm the ID token claims.
actualClaims := storedSessionFromAuthcode.Claims
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
if wantDownstreamIDTokenGroups != nil {
require.Len(t, actualClaims.Extra, 1)
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
} else {
require.Empty(t, actualClaims.Extra)
require.NotContains(t, actualClaims.Extra, "groups")
}
// Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time).
testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor)
testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.AuthTime, timeComparisonFudgeFactor)
requestedAtZone, _ := actualClaims.RequestedAt.Zone()
require.Equal(t, "UTC", requestedAtZone)
authTimeZone, _ := actualClaims.AuthTime.Zone()
require.Equal(t, "UTC", authTimeZone)
// Fosite will set these fields for us in the token endpoint based on the store session
// information. Therefore, we assert that they are empty because we want the library to do the
// lifting for us.
require.Empty(t, actualClaims.Issuer)
require.Nil(t, actualClaims.Audience)
require.Empty(t, actualClaims.Nonce)
require.Zero(t, actualClaims.ExpiresAt)
require.Zero(t, actualClaims.IssuedAt)
// These are not needed yet.
require.Empty(t, actualClaims.JTI)
require.Empty(t, actualClaims.CodeHash)
require.Empty(t, actualClaims.AccessTokenHash)
require.Empty(t, actualClaims.AuthenticationContextClassReference)
require.Empty(t, actualClaims.AuthenticationMethodsReference)
return storedRequestFromAuthcode, storedSessionFromAuthcode
}
func validatePKCEStorage(
t *testing.T,
oauthStore *oidc.KubeStorage,
storeKey string,
storedRequestFromAuthcode *fosite.Request,
storedSessionFromAuthcode *openid.DefaultSession,
wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod string,
) {
t.Helper()
storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil)
require.NoError(t, err)
// Check that storage returned the expected concrete data types.
storedRequestFromPKCE, storedSessionFromPKCE := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromPKCE)
// The stored PKCE request should be the same as the stored authcode request.
require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromPKCE.ID)
require.Equal(t, storedSessionFromAuthcode, storedSessionFromPKCE)
// The stored PKCE request should also contain the PKCE challenge that the downstream sent us.
require.Equal(t, wantDownstreamPKCEChallenge, storedRequestFromPKCE.Form.Get("code_challenge"))
require.Equal(t, wantDownstreamPKCEChallengeMethod, storedRequestFromPKCE.Form.Get("code_challenge_method"))
}
func validateIDSessionStorage(
t *testing.T,
oauthStore *oidc.KubeStorage,
storeKey string,
storedRequestFromAuthcode *fosite.Request,
storedSessionFromAuthcode *openid.DefaultSession,
wantDownstreamNonce string,
) {
t.Helper()
storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil)
require.NoError(t, err)
// Check that storage returned the expected concrete data types.
storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession)
// The stored IDSession request should be the same as the stored authcode request.
require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID)
require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession)
// The stored IDSession request should also contain the nonce that the downstream sent us.
require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce"))
}
func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *openid.DefaultSession) {
t.Helper()
storedRequest, ok := storedAuthorizeRequest.(*fosite.Request)
require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest, &fosite.Request{})
storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession)
require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest.GetSession(), &openid.DefaultSession{})
return storedRequest, storedSession
}
func requireAnyStringHasPrefix(t *testing.T, stringList []string, prefix string) {
t.Helper()
containsPrefix := false
for i := range stringList {
if strings.HasPrefix(stringList[i], prefix) {
containsPrefix = true
}
}
require.Truef(t, containsPrefix, "list %v did not contain any strings with prefix %s", stringList, prefix)
}

View File

@ -0,0 +1,120 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidc
import (
"context"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/openid"
fositepkce "github.com/ory/fosite/handler/pkce"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/fositestorage/authorizationcode"
"go.pinniped.dev/internal/fositestorage/openidconnect"
"go.pinniped.dev/internal/fositestorage/pkce"
)
const errKubeStorageNotImplemented = constable.Error("KubeStorage does not implement this method. It should not have been called.")
type KubeStorage struct {
authorizationCodeStorage oauth2.AuthorizeCodeStorage
pkceStorage fositepkce.PKCERequestStorage
oidcStorage openid.OpenIDConnectRequestStorage
}
func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage {
return &KubeStorage{
authorizationCodeStorage: authorizationcode.New(secrets),
pkceStorage: pkce.New(secrets),
oidcStorage: openidconnect.New(secrets),
}
}
func (KubeStorage) RevokeRefreshToken(_ context.Context, _ string) error {
return errKubeStorageNotImplemented
}
func (KubeStorage) RevokeAccessToken(_ context.Context, _ string) error {
return errKubeStorageNotImplemented
}
func (KubeStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
return nil
}
func (KubeStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
return nil, errKubeStorageNotImplemented
}
func (KubeStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) {
return errKubeStorageNotImplemented
}
func (KubeStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
return nil
}
func (KubeStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
return nil, errKubeStorageNotImplemented
}
func (KubeStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) {
return errKubeStorageNotImplemented
}
func (k KubeStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error {
return k.oidcStorage.CreateOpenIDConnectSession(ctx, authcode, requester)
}
func (k KubeStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) (fosite.Requester, error) {
return k.oidcStorage.GetOpenIDConnectSession(ctx, authcode, requester)
}
func (k KubeStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error {
return k.oidcStorage.DeleteOpenIDConnectSession(ctx, authcode)
}
func (k KubeStorage) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
return k.pkceStorage.GetPKCERequestSession(ctx, signature, session)
}
func (k KubeStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error {
return k.pkceStorage.CreatePKCERequestSession(ctx, signature, requester)
}
func (k KubeStorage) DeletePKCERequestSession(ctx context.Context, signature string) error {
return k.pkceStorage.DeletePKCERequestSession(ctx, signature)
}
func (k KubeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, r fosite.Requester) (err error) {
return k.authorizationCodeStorage.CreateAuthorizeCodeSession(ctx, signature, r)
}
func (k KubeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, s fosite.Session) (request fosite.Requester, err error) {
return k.authorizationCodeStorage.GetAuthorizeCodeSession(ctx, signature, s)
}
func (k KubeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) {
return k.authorizationCodeStorage.InvalidateAuthorizeCodeSession(ctx, signature)
}
func (KubeStorage) GetClient(_ context.Context, id string) (fosite.Client, error) {
client := PinnipedCLIOIDCClient()
if client.ID == id {
return client, nil
}
return nil, fosite.ErrNotFound
}
func (KubeStorage) ClientAssertionJWTValid(_ context.Context, _ string) error {
return errKubeStorageNotImplemented
}
func (KubeStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error {
return errKubeStorageNotImplemented
}

View File

@ -12,16 +12,16 @@ import (
"go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/constable"
) )
const errNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.") const errNullStorageNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.")
type NullStorage struct{} type NullStorage struct{}
func (NullStorage) RevokeRefreshToken(_ context.Context, _ string) error { func (NullStorage) RevokeRefreshToken(_ context.Context, _ string) error {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) RevokeAccessToken(_ context.Context, _ string) error { func (NullStorage) RevokeAccessToken(_ context.Context, _ string) error {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
@ -29,11 +29,11 @@ func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosi
} }
func (NullStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { func (NullStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
return nil, errNotImplemented return nil, errNullStorageNotImplemented
} }
func (NullStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) { func (NullStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) { func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
@ -41,11 +41,11 @@ func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosit
} }
func (NullStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { func (NullStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
return nil, errNotImplemented return nil, errNullStorageNotImplemented
} }
func (NullStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) { func (NullStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error { func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error {
@ -53,15 +53,15 @@ func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fos
} }
func (NullStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) { func (NullStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) {
return nil, errNotImplemented return nil, errNullStorageNotImplemented
} }
func (NullStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error { func (NullStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) { func (NullStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) {
return nil, errNotImplemented return nil, errNullStorageNotImplemented
} }
func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error { func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error {
@ -69,7 +69,7 @@ func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosit
} }
func (NullStorage) DeletePKCERequestSession(_ context.Context, _ string) error { func (NullStorage) DeletePKCERequestSession(_ context.Context, _ string) error {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Requester) (err error) { func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
@ -77,11 +77,11 @@ func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fos
} }
func (NullStorage) GetAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) { func (NullStorage) GetAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
return nil, errNotImplemented return nil, errNullStorageNotImplemented
} }
func (NullStorage) InvalidateAuthorizeCodeSession(_ context.Context, _ string) (err error) { func (NullStorage) InvalidateAuthorizeCodeSession(_ context.Context, _ string) (err error) {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error) { func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error) {
@ -93,9 +93,9 @@ func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error
} }
func (NullStorage) ClientAssertionJWTValid(_ context.Context, _ string) error { func (NullStorage) ClientAssertionJWTValid(_ context.Context, _ string) error {
return errNotImplemented return errNullStorageNotImplemented
} }
func (NullStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error { func (NullStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error {
return errNotImplemented return errNullStorageNotImplemented
} }

View File

@ -10,15 +10,72 @@ import (
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce"
) )
const ( const (
WellKnownEndpointPath = "/.well-known/openid-configuration" WellKnownEndpointPath = "/.well-known/openid-configuration"
AuthorizationEndpointPath = "/oauth2/authorize" AuthorizationEndpointPath = "/oauth2/authorize"
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
CallbackEndpointPath = "/callback"
JWKSEndpointPath = "/jwks.json" JWKSEndpointPath = "/jwks.json"
) )
const (
// Just in case we need to make a breaking change to the format of the upstream state param,
// we are including a format version number. This gives the opportunity for a future version of Pinniped
// to have the consumer of this format decide to reject versions that it doesn't understand.
UpstreamStateParamFormatVersion = "1"
// The `name` passed to the encoder for encoding the upstream state param value. This name is short
// because it will be encoded into the upstream state param value and we're trying to keep that small.
UpstreamStateParamEncodingName = "s"
// CSRFCookieName is the name of the browser cookie which shall hold our CSRF value.
// The `__Host` prefix has a special meaning. See:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes.
CSRFCookieName = "__Host-pinniped-csrf"
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
// cookie contents.
CSRFCookieEncodingName = "csrf"
)
// Encoder is the encoding side of the securecookie.Codec interface.
type Encoder interface {
Encode(name string, value interface{}) (string, error)
}
// Decoder is the decoding side of the securecookie.Codec interface.
type Decoder interface {
Decode(name, value string, into interface{}) error
}
// Codec is both the encoding and decoding sides of the securecookie.Codec interface. It is
// interface'd here so that we properly wrap the securecookie dependency.
type Codec interface {
Encoder
Decoder
}
// UpstreamStateParamData is the format of the state parameter that we use when we communicate to an
// upstream OIDC provider.
//
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on
// the state param.
type UpstreamStateParamData struct {
AuthParams string `json:"p"`
UpstreamName string `json:"u"`
Nonce nonce.Nonce `json:"n"`
CSRFToken csrftoken.CSRFToken `json:"c"`
PKCECode pkce.Code `json:"k"`
FormatVersion string `json:"v"`
}
func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient {
return &fosite.DefaultOpenIDConnectClient{ return &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{ DefaultClient: &fosite.DefaultClient{
@ -34,8 +91,8 @@ func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient {
} }
func FositeOauth2Helper( func FositeOauth2Helper(
issuerURL string, oauthStore interface{},
oauthStore fosite.Storage, issuer string,
hmacSecretOfLengthAtLeast32 []byte, hmacSecretOfLengthAtLeast32 []byte,
jwtSigningKey *ecdsa.PrivateKey, jwtSigningKey *ecdsa.PrivateKey,
) fosite.OAuth2Provider { ) fosite.OAuth2Provider {
@ -47,7 +104,7 @@ func FositeOauth2Helper(
RefreshTokenLifespan: 16 * time.Hour, // long enough for a single workday RefreshTokenLifespan: 16 * time.Hour, // long enough for a single workday
IDTokenIssuer: issuerURL, IDTokenIssuer: issuer,
TokenURL: "", // TODO set once we have this endpoint written TokenURL: "", // TODO set once we have this endpoint written
ScopeStrategy: fosite.ExactScopeStrategy, // be careful and only support exact string matching for scopes ScopeStrategy: fosite.ExactScopeStrategy, // be careful and only support exact string matching for scopes
@ -75,3 +132,7 @@ func FositeOauth2Helper(
compose.OAuth2PKCEFactory, compose.OAuth2PKCEFactory,
) )
} }
type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProviderI
}

View File

@ -0,0 +1,129 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidctestutil
import (
"context"
"net/url"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
)
// Test helpers for the OIDC package.
// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc().
type ExchangeAuthcodeAndValidateTokenArgs struct {
Ctx context.Context
Authcode string
PKCECodeVerifier pkce.Code
ExpectedIDTokenNonce nonce.Nonce
RedirectURI string
}
type TestUpstreamOIDCIdentityProvider struct {
Name string
ClientID string
AuthorizationURL url.URL
UsernameClaim string
GroupsClaim string
Scopes []string
ExchangeAuthcodeAndValidateTokensFunc func(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
) (oidctypes.Token, map[string]interface{}, error)
exchangeAuthcodeAndValidateTokensCallCount int
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
}
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
return u.Name
}
func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string {
return u.ClientID
}
func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL {
return &u.AuthorizationURL
}
func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string {
return u.Scopes
}
func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string {
return u.UsernameClaim
}
func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string {
return u.GroupsClaim
}
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (oidctypes.Token, map[string]interface{}, error) {
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
}
u.exchangeAuthcodeAndValidateTokensCallCount++
u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{
Ctx: ctx,
Authcode: authcode,
PKCECodeVerifier: pkceCodeVerifier,
ExpectedIDTokenNonce: expectedIDTokenNonce,
RedirectURI: redirectURI,
})
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
}
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int {
return u.exchangeAuthcodeAndValidateTokensCallCount
}
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs {
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
}
return u.exchangeAuthcodeAndValidateTokensArgs[call]
}
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
panic("implement me")
}
func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider()
upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
for i := range upstreamOIDCIdentityProviders {
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i])
}
idpProvider.SetIDPList(upstreams)
return idpProvider
}
// Declare a separate type from the production code to ensure that the state param's contents was serialized
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
// assertions about the redirect URL in this test.
type ExpectedUpstreamStateParamFormat struct {
P string `json:"p"`
U string `json:"u"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}

View File

@ -4,48 +4,73 @@
package provider package provider
import ( import (
"context"
"net/url" "net/url"
"sync" "sync"
"golang.org/x/oauth2"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
) )
type UpstreamOIDCIdentityProvider struct { type UpstreamOIDCIdentityProviderI interface {
// A name for this upstream provider, which will be used as a component of the path for the callback endpoint // A name for this upstream provider, which will be used as a component of the path for the callback endpoint
// hosted by the Supervisor. // hosted by the Supervisor.
Name string GetName() string
// The Oauth client ID registered with the upstream provider to be used in the authorization flow. // The Oauth client ID registered with the upstream provider to be used in the authorization code flow.
ClientID string GetClientID() string
// The Authorization Endpoint fetched from discovery. // The Authorization Endpoint fetched from discovery.
AuthorizationURL url.URL GetAuthorizationURL() *url.URL
// Scopes to request in authorization flow. // Scopes to request in authorization flow.
Scopes []string GetScopes() []string
// ID Token username claim name. May return empty string, in which case we will use some reasonable defaults.
GetUsernameClaim() string
// ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider.
GetGroupsClaim() string
// Performs upstream OIDC authorization code exchange and token validation.
// Returns the validated raw tokens as well as the parsed claims of the ID token.
ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error)
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error)
} }
type DynamicUpstreamIDPProvider interface { type DynamicUpstreamIDPProvider interface {
SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI)
GetIDPList() []UpstreamOIDCIdentityProvider GetIDPList() []UpstreamOIDCIdentityProviderI
} }
type dynamicUpstreamIDPProvider struct { type dynamicUpstreamIDPProvider struct {
oidcProviders []UpstreamOIDCIdentityProvider oidcProviders []UpstreamOIDCIdentityProviderI
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider { func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider {
return &dynamicUpstreamIDPProvider{ return &dynamicUpstreamIDPProvider{
oidcProviders: []UpstreamOIDCIdentityProvider{}, oidcProviders: []UpstreamOIDCIdentityProviderI{},
} }
} }
func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) { func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) {
p.mutex.Lock() // acquire a write lock p.mutex.Lock() // acquire a write lock
defer p.mutex.Unlock() defer p.mutex.Unlock()
p.oidcProviders = oidcIDPs p.oidcProviders = oidcIDPs
} }
func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProvider { func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProviderI {
p.mutex.RLock() // acquire a read lock p.mutex.RLock() // acquire a read lock
defer p.mutex.RUnlock() defer p.mutex.RUnlock()
return p.oidcProviders return p.oidcProviders

View File

@ -9,9 +9,11 @@ import (
"sync" "sync"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/auth" "go.pinniped.dev/internal/oidc/auth"
"go.pinniped.dev/internal/oidc/callback"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
@ -30,19 +32,26 @@ type Manager struct {
providerHandlers map[string]http.Handler // map of all routes for all providers providerHandlers map[string]http.Handler // map of all routes for all providers
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
secretsClient corev1client.SecretInterface
} }
// NewManager returns an empty Manager. // NewManager returns an empty Manager.
// nextHandler will be invoked for any requests that could not be handled by this manager's providers. // nextHandler will be invoked for any requests that could not be handled by this manager's providers.
// dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data.
// idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs.
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager { func NewManager(
nextHandler http.Handler,
dynamicJWKSProvider jwks.DynamicJWKSProvider,
idpListGetter oidc.IDPListGetter,
secretsClient corev1client.SecretInterface,
) *Manager {
return &Manager{ return &Manager{
providerHandlers: make(map[string]http.Handler), providerHandlers: make(map[string]http.Handler),
nextHandler: nextHandler, nextHandler: nextHandler,
dynamicJWKSProvider: dynamicJWKSProvider, dynamicJWKSProvider: dynamicJWKSProvider,
idpListGetter: idpListGetter, idpListGetter: idpListGetter,
secretsClient: secretsClient,
} }
} }
@ -62,20 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
m.providerHandlers = make(map[string]http.Handler) m.providerHandlers = make(map[string]http.Handler)
for _, incomingProvider := range oidcProviders { for _, incomingProvider := range oidcProviders {
wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath issuer := incomingProvider.Issuer()
m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer()) issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath()
jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret
m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider)
// Use NullStorage for the authorize endpoint because we do not actually want to store anything until // Use NullStorage for the authorize endpoint because we do not actually want to store anything until
// the upstream callback endpoint is called later. // the upstream callback endpoint is called later.
oauthHelper := oidc.FositeOauth2Helper( oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider, nil)
incomingProvider.Issuer(),
oidc.NullStorage{}, // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
[]byte("some secret - must have at least 32 bytes"), // TODO replace this secret oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider, nil)
nil, // TODO: inject me properly
)
// TODO use different codecs for the state and the cookie, because: // TODO use different codecs for the state and the cookie, because:
// 1. we would like to state to have an embedded expiration date while the cookie does not need that // 1. we would like to state to have an embedded expiration date while the cookie does not need that
@ -86,10 +92,30 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
var encoder = securecookie.New(encoderHashKey, encoderBlockKey) var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
encoder.SetSerializer(securecookie.JSONEncoder{}) encoder.SetSerializer(securecookie.JSONEncoder{})
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer)
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder)
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider)
m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler(
issuer,
m.idpListGetter,
oauthHelperWithNullStorage,
csrftoken.Generate,
pkce.Generate,
nonce.Generate,
encoder,
encoder,
)
m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler(
m.idpListGetter,
oauthHelperWithKubeStorage,
encoder,
encoder,
issuer+oidc.CallbackEndpointPath,
)
plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer)
} }
} }

View File

@ -4,6 +4,7 @@
package manager package manager
import ( import (
"context"
"encoding/json" "encoding/json"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -15,12 +16,17 @@ import (
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"k8s.io/client-go/kubernetes/fake"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
) )
func TestManager(t *testing.T) { func TestManager(t *testing.T) {
@ -31,6 +37,7 @@ func TestManager(t *testing.T) {
nextHandler http.HandlerFunc nextHandler http.HandlerFunc
fallbackHandlerWasCalled bool fallbackHandlerWasCalled bool
dynamicJWKSProvider jwks.DynamicJWKSProvider dynamicJWKSProvider jwks.DynamicJWKSProvider
kubeClient *fake.Clientset
) )
const ( const (
@ -41,6 +48,7 @@ func TestManager(t *testing.T) {
issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path"
issuer2KeyID = "issuer2-key" issuer2KeyID = "issuer2-key"
upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" upstreamIDPAuthorizationURL = "https://test-upstream.com/auth"
downstreamRedirectURL = "http://127.0.0.1:12345/callback"
) )
newGetRequest := func(url string) *http.Request { newGetRequest := func(url string) *http.Request {
@ -64,7 +72,7 @@ func TestManager(t *testing.T) {
r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer) r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer)
} }
requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) { requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix)) subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix))
@ -79,6 +87,58 @@ func TestManager(t *testing.T) {
"actual location %s did not start with expected prefix %s", "actual location %s did not start with expected prefix %s",
actualLocation, expectedRedirectLocationPrefix, actualLocation, expectedRedirectLocationPrefix,
) )
parsedLocation, err := url.Parse(actualLocation)
r.NoError(err)
redirectStateParam := parsedLocation.Query().Get("state")
r.NotEmpty(redirectStateParam)
cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2)
r.Len(cookieValueAndDirectivesSplit, 2)
cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=")
r.Len(cookieKeyValueSplit, 2)
csrfCookieName := cookieKeyValueSplit[0]
r.Equal("__Host-pinniped-csrf", csrfCookieName)
csrfCookieValue := cookieKeyValueSplit[1]
r.NotEmpty(csrfCookieValue)
// Return the important parts of the response so we can use them in our next request to the callback endpoint
return csrfCookieValue, redirectStateParam
}
requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) {
recorder := httptest.NewRecorder()
numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions())
getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix)
getRequest.AddCookie(&http.Cookie{
Name: "__Host-pinniped-csrf",
Value: csrfCookieValue,
})
subject.ServeHTTP(recorder, getRequest)
r.False(fallbackHandlerWasCalled)
// Check just enough of the response to ensure that we wired up the callback endpoint correctly.
// The endpoint's own unit tests cover everything else.
r.Equal(http.StatusFound, recorder.Code)
actualLocation := recorder.Header().Get("Location")
r.True(
strings.HasPrefix(actualLocation, downstreamRedirectURL),
"actual location %s did not start with expected prefix %s",
actualLocation, downstreamRedirectURL,
)
parsedLocation, err := url.Parse(actualLocation)
r.NoError(err)
actualLocationQueryParams := parsedLocation.Query()
r.Contains(actualLocationQueryParams, "code")
r.Equal("openid", actualLocationQueryParams.Get("scope"))
r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state"))
// Make sure that we wired up the callback endpoint to use kube storage for fosite sessions.
r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3,
"did not perform any kube actions during the callback request, but should have")
} }
requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) {
@ -107,17 +167,27 @@ func TestManager(t *testing.T) {
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
r.NoError(err) r.NoError(err)
idpListGetter := provider.NewDynamicUpstreamIDPProvider() idpListGetter := oidctestutil.NewIDPListGetter(&oidctestutil.TestUpstreamOIDCIdentityProvider{
idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{
{
Name: "test-idp", Name: "test-idp",
ClientID: "test-client-id", ClientID: "test-client-id",
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
Scopes: []string{"test-scope"}, Scopes: []string{"test-scope"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
return oidctypes.Token{},
map[string]interface{}{
"iss": "https://some-issuer.com",
"sub": "some-subject",
"username": "test-username",
"groups": "test-group1",
},
nil
}, },
}) })
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) kubeClient = fake.NewSimpleClientset()
secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient)
}) })
when("given no providers via SetProviders()", func() { when("given no providers via SetProviders()", func() {
@ -164,7 +234,6 @@ func TestManager(t *testing.T) {
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID)
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID)
authRedirectURI := "http://127.0.0.1/callback"
authRequestParams := "?" + url.Values{ authRequestParams := "?" + url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"openid profile email"}, "scope": []string{"openid profile email"},
@ -173,7 +242,7 @@ func TestManager(t *testing.T) {
"nonce": []string{"some-nonce-value"}, "nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"}, "code_challenge": []string{"some-challenge"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"redirect_uri": []string{authRedirectURI}, "redirect_uri": []string{downstreamRedirectURL},
}.Encode() }.Encode()
requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL)
@ -181,7 +250,20 @@ func TestManager(t *testing.T) {
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
csrfCookieValue, upstreamStateParam :=
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
callbackRequestParams := "?" + url.Values{
"code": []string{"some-fake-code"},
"state": []string{upstreamStateParam},
}.Encode()
requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue)
requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue)
// // Hostnames are case-insensitive, so test that we can handle that.
requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
} }
when("given some valid providers via SetProviders()", func() { when("given some valid providers via SetProviders()", func() {

View File

@ -0,0 +1,24 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
}

View File

@ -0,0 +1,117 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions.
package upstreamoidc
import (
"context"
"net/http"
"net/url"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
)
func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
return &ProviderConfig{Config: config, Provider: provider, Client: client}
}
// ProviderConfig holds the active configuration of an upstream OIDC provider.
type ProviderConfig struct {
Name string
UsernameClaim string
GroupsClaim string
Config *oauth2.Config
Provider interface {
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
Client *http.Client
}
func (p *ProviderConfig) GetName() string {
return p.Name
}
func (p *ProviderConfig) GetClientID() string {
return p.Config.ClientID
}
func (p *ProviderConfig) GetAuthorizationURL() *url.URL {
result, _ := url.Parse(p.Config.Endpoint.AuthURL)
return result
}
func (p *ProviderConfig) GetScopes() []string {
return p.Config.Scopes
}
func (p *ProviderConfig) GetUsernameClaim() string {
return p.UsernameClaim
}
func (p *ProviderConfig) GetGroupsClaim() string {
return p.GroupsClaim
}
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) {
tok, err := p.Config.Exchange(
oidc.ClientContext(ctx, p.Client),
authcode,
pkceCodeVerifier.Verifier(),
oauth2.SetAuthURLParam("redirect_uri", redirectURI),
)
if err != nil {
return oidctypes.Token{}, nil, err
}
return p.ValidateToken(ctx, tok, expectedIDTokenNonce)
}
func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok)
if err != nil {
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if expectedIDTokenNonce != "" {
if err := expectedIDTokenNonce.Validate(validated); err != nil {
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
var validatedClaims map[string]interface{}
if err := validated.Claims(&validatedClaims); err != nil {
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err)
}
return oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: tok.AccessToken,
Type: tok.TokenType,
Expiry: metav1.NewTime(tok.Expiry),
},
RefreshToken: &oidctypes.RefreshToken{
Token: tok.RefreshToken,
},
IDToken: &oidctypes.IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}, validatedClaims, nil
}

View File

@ -0,0 +1,223 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package upstreamoidc
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
)
func TestProviderConfig(t *testing.T) {
t.Run("getters get", func(t *testing.T) {
p := ProviderConfig{
Name: "test-name",
UsernameClaim: "test-username-claim",
GroupsClaim: "test-groups-claim",
Config: &oauth2.Config{
ClientID: "test-client-id",
Endpoint: oauth2.Endpoint{AuthURL: "https://example.com"},
Scopes: []string{"scope1", "scope2"},
},
}
require.Equal(t, "test-name", p.GetName())
require.Equal(t, "test-client-id", p.GetClientID())
require.Equal(t, "https://example.com", p.GetAuthorizationURL().String())
require.ElementsMatch(t, []string{"scope1", "scope2"}, p.GetScopes())
require.Equal(t, "test-username-claim", p.GetUsernameClaim())
require.Equal(t, "test-groups-claim", p.GetGroupsClaim())
})
const (
// Test JWTs generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
invalidAccessTokenHashIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg" //nolint: gosec
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
invalidNonceIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ" //nolint: gosec
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"foo": "bar", "bat": "baz"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
validIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImJhdCI6ImJheiIsImZvbyI6ImJhciIsImlhdCI6MTYwNjc2ODU5MywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDY3Njg1OTMsInN1YiI6InRlc3QtdXNlciJ9.DuqVZ7pGhHqKz7gNr4j2W1s1N8YrSltktH4wW19L4oD1OE2-O72jAnNj5xdjilsa8l7h9ox-5sMF0Tkh3BdRlHQK9dEtNm9tW-JreUnWJ3LCqUs-LZp4NG7edvq2sH_1Bn7O2_NQV51s8Pl04F60CndjQ4NM-6WkqDQTKyY6vJXU7idvM-6TM2HJZK-Na88cOJ9KIK37tL5DhcbsHVF47Dq8uPZ0KbjNQjJLAIi_1GeQBgc6yJhDUwRY4Xu6S0dtTHA6xTI8oSXoamt4bkViEHfJBp97LZQiNz8mku5pVc0aNwP1p4hMHxRHhLXrJjbh-Hx4YFjxtOnIq9t1mHlD4A" //nolint: gosec
)
tests := []struct {
name string
authCode string
expectNonce nonce.Nonce
returnIDTok string
wantErr string
wantToken oidctypes.Token
wantClaims map[string]interface{}
}{
{
name: "exchange fails with network error",
authCode: "invalid-auth-code",
wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
},
{
name: "missing ID token",
authCode: "valid",
wantErr: "received response missing ID token",
},
{
name: "invalid ID token",
authCode: "valid",
returnIDTok: "invalid-jwt",
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
},
{
name: "invalid access token hash",
authCode: "valid",
returnIDTok: invalidAccessTokenHashIDToken,
wantErr: "received invalid ID token: access token hash does not match value in ID token",
},
{
name: "invalid nonce",
authCode: "valid",
expectNonce: "test-nonce",
returnIDTok: invalidNonceIDToken,
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
},
{
name: "invalid nonce but not checked",
authCode: "valid",
expectNonce: "",
returnIDTok: invalidNonceIDToken,
wantToken: oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: "test-access-token",
Expiry: metav1.Time{},
},
RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token",
},
IDToken: &oidctypes.IDToken{
Token: invalidNonceIDToken,
Expiry: metav1.Time{},
},
},
},
{
name: "valid",
authCode: "valid",
returnIDTok: validIDToken,
wantToken: oidctypes.Token{
AccessToken: &oidctypes.AccessToken{
Token: "test-access-token",
Expiry: metav1.Time{},
},
RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token",
},
IDToken: &oidctypes.IDToken{
Token: validIDToken,
Expiry: metav1.Time{},
},
},
wantClaims: map[string]interface{}{
"foo": "bar",
"bat": "baz",
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
require.NotEmpty(t, r.Form.Get("code"))
if r.Form.Get("code") != "valid" {
http.Error(w, "invalid authorization code", http.StatusForbidden)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
}
response.AccessToken = "test-access-token"
response.RefreshToken = "test-refresh-token"
response.Expiry = time.Now().Add(time.Hour)
response.IDToken = tt.returnIDTok
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
p := ProviderConfig{
Name: "test-name",
UsernameClaim: "test-username-claim",
GroupsClaim: "test-groups-claim",
Config: &oauth2.Config{
ClientID: "test-client-id",
Endpoint: oauth2.Endpoint{
AuthURL: "https://example.com",
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
Scopes: []string{"scope1", "scope2"},
},
Provider: &mockProvider{},
}
ctx := context.Background()
tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback")
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Equal(t, oidctypes.Token{}, tok)
require.Nil(t, claims)
return
}
require.NoError(t, err)
require.Equal(t, tt.wantToken, tok)
for k, v := range tt.wantClaims {
require.Equal(t, v, claims[k])
}
})
}
}
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
func mockVerifier() *oidc.IDTokenVerifier {
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
AnyTimes().
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
return jws.UnsafePayloadWithoutVerification(), nil
})
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
SkipIssuerCheck: true,
SkipExpiryCheck: true,
SkipClientIDCheck: true,
})
}
type mockProvider struct{}
func (m *mockProvider) Verifier(_ *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }

View File

@ -17,6 +17,7 @@ import (
"sigs.k8s.io/yaml" "sigs.k8s.io/yaml"
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
var ( var (
@ -48,7 +49,7 @@ type (
Key oidcclient.SessionCacheKey `json:"key"` Key oidcclient.SessionCacheKey `json:"key"`
CreationTimestamp metav1.Time `json:"creationTimestamp"` CreationTimestamp metav1.Time `json:"creationTimestamp"`
LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"` LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"`
Tokens oidcclient.Token `json:"tokens"` Tokens oidctypes.Token `json:"tokens"`
} }
) )

View File

@ -13,6 +13,7 @@ import (
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
// validSession should be the same data as `testdata/valid.yaml`. // validSession should be the same data as `testdata/valid.yaml`.
@ -28,17 +29,17 @@ var validSession = sessionCache{
}, },
CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()), CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()),
LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()), LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()), Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()), Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
}, },
@ -140,8 +141,8 @@ func TestNormalized(t *testing.T) {
// ID token is empty, but not nil. // ID token is empty, but not nil.
{ {
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "", Token: "",
Expiry: metav1.NewTime(now.Add(1 * time.Minute)), Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
}, },
@ -150,8 +151,8 @@ func TestNormalized(t *testing.T) {
// ID token is expired. // ID token is expired.
{ {
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
}, },
@ -160,8 +161,8 @@ func TestNormalized(t *testing.T) {
// Access token is empty, but not nil. // Access token is empty, but not nil.
{ {
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "", Token: "",
Expiry: metav1.NewTime(now.Add(1 * time.Minute)), Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
}, },
@ -170,8 +171,8 @@ func TestNormalized(t *testing.T) {
// Access token is expired. // Access token is expired.
{ {
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)), Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
}, },
@ -180,8 +181,8 @@ func TestNormalized(t *testing.T) {
// Refresh token is empty, but not nil. // Refresh token is empty, but not nil.
{ {
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "", Token: "",
}, },
}, },
@ -189,8 +190,8 @@ func TestNormalized(t *testing.T) {
// Session has a refresh token but it hasn't been used in >90 days. // Session has a refresh token but it hasn't been used in >90 days.
{ {
LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)), LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
}, },
@ -199,8 +200,8 @@ func TestNormalized(t *testing.T) {
{ {
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token2", Token: "test-refresh-token2",
}, },
}, },
@ -208,8 +209,8 @@ func TestNormalized(t *testing.T) {
{ {
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token1", Token: "test-refresh-token1",
}, },
}, },
@ -223,8 +224,8 @@ func TestNormalized(t *testing.T) {
{ {
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token1", Token: "test-refresh-token1",
}, },
}, },
@ -232,8 +233,8 @@ func TestNormalized(t *testing.T) {
{ {
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now), LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token2", Token: "test-refresh-token2",
}, },
}, },

View File

@ -16,6 +16,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
const ( const (
@ -65,14 +66,14 @@ type Cache struct {
} }
// GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached. // GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached.
func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token { func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidctypes.Token {
// If the cache file does not exist, exit immediately with no error log // If the cache file does not exist, exit immediately with no error log
if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) {
return nil return nil
} }
// Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it. // Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it.
var result *oidcclient.Token var result *oidctypes.Token
c.withCache(func(cache *sessionCache) { c.withCache(func(cache *sessionCache) {
if entry := cache.lookup(key); entry != nil { if entry := cache.lookup(key); entry != nil {
result = &entry.Tokens result = &entry.Tokens
@ -84,7 +85,7 @@ func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token {
// PutToken stores the provided token into the session cache under the given parameters. It does not return an error // PutToken stores the provided token into the session cache under the given parameters. It does not return an error
// but may silently fail to update the session cache. // but may silently fail to update the session cache.
func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidcclient.Token) { func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidctypes.Token) {
// Create the cache directory if it does not exist. // Create the cache directory if it does not exist.
if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) { if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) {
c.errReporter(fmt.Errorf("could not create session cache directory: %w", err)) c.errReporter(fmt.Errorf("could not create session cache directory: %w", err))

View File

@ -17,6 +17,7 @@ import (
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -38,7 +39,7 @@ func TestGetToken(t *testing.T) {
trylockFunc func(*testing.T) error trylockFunc func(*testing.T) error
unlockFunc func(*testing.T) error unlockFunc func(*testing.T) error
key oidcclient.SessionCacheKey key oidcclient.SessionCacheKey
want *oidcclient.Token want *oidctypes.Token
wantErrors []string wantErrors []string
wantTestFile func(t *testing.T, tmp string) wantTestFile func(t *testing.T, tmp string)
}{ }{
@ -99,17 +100,17 @@ func TestGetToken(t *testing.T) {
}, },
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
}, },
@ -137,17 +138,17 @@ func TestGetToken(t *testing.T) {
}, },
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
}, },
@ -161,17 +162,17 @@ func TestGetToken(t *testing.T) {
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
}, },
wantErrors: []string{}, wantErrors: []string{},
want: &oidcclient.Token{ want: &oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "test-access-token", Token: "test-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
}, },
@ -219,7 +220,7 @@ func TestPutToken(t *testing.T) {
name string name string
makeTestFile func(t *testing.T, tmp string) makeTestFile func(t *testing.T, tmp string)
key oidcclient.SessionCacheKey key oidcclient.SessionCacheKey
token *oidcclient.Token token *oidctypes.Token
wantErrors []string wantErrors []string
wantTestFile func(t *testing.T, tmp string) wantTestFile func(t *testing.T, tmp string)
}{ }{
@ -245,17 +246,17 @@ func TestPutToken(t *testing.T) {
}, },
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "old-access-token", Token: "old-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "old-id-token", Token: "old-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "old-refresh-token", Token: "old-refresh-token",
}, },
}, },
@ -269,17 +270,17 @@ func TestPutToken(t *testing.T) {
Scopes: []string{"email", "offline_access", "openid", "profile"}, Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
}, },
token: &oidcclient.Token{ token: &oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "new-access-token", Token: "new-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "new-id-token", Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "new-refresh-token", Token: "new-refresh-token",
}, },
}, },
@ -288,17 +289,17 @@ func TestPutToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, cache.Sessions, 1) require.Len(t, cache.Sessions, 1)
require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
require.Equal(t, oidcclient.Token{ require.Equal(t, oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "new-access-token", Token: "new-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "new-id-token", Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "new-refresh-token", Token: "new-refresh-token",
}, },
}, cache.Sessions[0].Tokens) }, cache.Sessions[0].Tokens)
@ -317,17 +318,17 @@ func TestPutToken(t *testing.T) {
}, },
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)), CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)), LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{ Tokens: oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "old-access-token", Token: "old-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "old-id-token", Token: "old-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)), Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "old-refresh-token", Token: "old-refresh-token",
}, },
}, },
@ -341,17 +342,17 @@ func TestPutToken(t *testing.T) {
Scopes: []string{"email", "offline_access", "openid", "profile"}, Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
}, },
token: &oidcclient.Token{ token: &oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "new-access-token", Token: "new-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "new-id-token", Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "new-refresh-token", Token: "new-refresh-token",
}, },
}, },
@ -360,17 +361,17 @@ func TestPutToken(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, cache.Sessions, 2) require.Len(t, cache.Sessions, 2)
require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds()) require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
require.Equal(t, oidcclient.Token{ require.Equal(t, oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "new-access-token", Token: "new-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "new-id-token", Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "new-refresh-token", Token: "new-refresh-token",
}, },
}, cache.Sessions[1].Tokens) }, cache.Sessions[1].Tokens)
@ -389,17 +390,17 @@ func TestPutToken(t *testing.T) {
Scopes: []string{"email", "offline_access", "openid", "profile"}, Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
}, },
token: &oidcclient.Token{ token: &oidctypes.Token{
AccessToken: &oidcclient.AccessToken{ AccessToken: &oidctypes.AccessToken{
Token: "new-access-token", Token: "new-access-token",
Type: "Bearer", Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
IDToken: &oidcclient.IDToken{ IDToken: &oidctypes.IDToken{
Token: "new-id-token", Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()), Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
}, },
RefreshToken: &oidcclient.RefreshToken{ RefreshToken: &oidctypes.RefreshToken{
Token: "new-refresh-token", Token: "new-refresh-token",
}, },
}, },

View File

@ -16,11 +16,13 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/upstreamoidc"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
"go.pinniped.dev/pkg/oidcclient/state" "go.pinniped.dev/pkg/oidcclient/state"
) )
@ -51,7 +53,7 @@ type handlerState struct {
callbackPath string callbackPath string
// Generated parameters of a login flow. // Generated parameters of a login flow.
idTokenVerifier *oidc.IDTokenVerifier provider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
state state.State state state.State
nonce nonce.Nonce nonce nonce.Nonce
@ -62,13 +64,13 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
oidcDiscover func(context.Context, string) (discoveryI, error) getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
callbacks chan callbackResult callbacks chan callbackResult
} }
type callbackResult struct { type callbackResult struct {
token *Token token *oidctypes.Token
err error err error
} }
@ -87,6 +89,7 @@ func WithContext(ctx context.Context) Option {
// WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the // WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the
// authorization code callback. By default, a random high port will be chosen which requires the authorization server // authorization code callback. By default, a random high port will be chosen which requires the authorization server
// to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252: // to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252:
//
// The authorization server MUST allow any port to be specified at the // The authorization server MUST allow any port to be specified at the
// time of the request for loopback IP redirect URIs, to accommodate // time of the request for loopback IP redirect URIs, to accommodate
// clients that obtain an available ephemeral port from the operating // clients that obtain an available ephemeral port from the operating
@ -116,6 +119,19 @@ func WithBrowserOpen(openURL func(url string) error) Option {
} }
} }
// SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct {
Issuer string `json:"issuer"`
ClientID string `json:"clientID"`
Scopes []string `json:"scopes"`
RedirectURI string `json:"redirect_uri"`
}
type SessionCache interface {
GetToken(SessionCacheKey) *oidctypes.Token
PutToken(SessionCacheKey, *oidctypes.Token)
}
// WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens. // WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens.
func WithSessionCache(cache SessionCache) Option { func WithSessionCache(cache SessionCache) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
@ -135,16 +151,11 @@ func WithClient(httpClient *http.Client) Option {
// nopCache is a SessionCache that doesn't actually do anything. // nopCache is a SessionCache that doesn't actually do anything.
type nopCache struct{} type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil } func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *Token) {} func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
type discoveryI interface {
Endpoint() oauth2.Endpoint
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener. // Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*Token, error) { func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
h := handlerState{ h := handlerState{
issuer: issuer, issuer: issuer,
clientID: clientID, clientID: clientID,
@ -161,9 +172,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { getProvider: upstreamoidc.New,
return oidc.NewProvider(ctx, iss)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -208,16 +217,15 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
} }
// Perform OIDC discovery. // Perform OIDC discovery.
discovered, err := h.oidcDiscover(h.ctx, h.issuer) h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
} }
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{ h.oauth2Config = &oauth2.Config{
ClientID: h.clientID, ClientID: h.clientID,
Endpoint: discovered.Endpoint(), Endpoint: h.provider.Endpoint(),
Scopes: h.scopes, Scopes: h.scopes,
} }
@ -274,7 +282,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
} }
} }
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) { func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTimeout) ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
defer cancel() defer cancel()
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
@ -287,7 +295,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshT
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations). // some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.validateToken(ctx, refreshed, false) token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
if err != nil {
return nil, err
}
return &token, nil
} }
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
@ -314,58 +326,25 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam) return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
} }
// Exchange the authorization code for access, ID, and refresh tokens. // Exchange the authorization code for access, ID, and refresh tokens and perform required
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier()) // validations on the returned ID token.
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).
ExchangeAuthcodeAndValidateTokens(
r.Context(),
params.Get("code"),
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }
// Perform required validations on the returned ID token. h.callbacks <- callbackResult{token: &token}
token, err := h.validateToken(r.Context(), oauth2Tok, true)
if err != nil {
return err
}
h.callbacks <- callbackResult{token: token}
_, _ = w.Write([]byte("you have been logged in and may now close this tab")) _, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil return nil
} }
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if checkNonce {
if err := h.nonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
return &Token{
AccessToken: &AccessToken{
Token: tok.AccessToken,
Type: tok.TokenType,
Expiry: metav1.NewTime(tok.Expiry),
},
RefreshToken: &RefreshToken{
Token: tok.RefreshToken,
},
IDToken: &IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}, nil
}
func (h *handlerState) serve(listener net.Listener) func() { func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))

View File

@ -18,12 +18,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
"go.pinniped.dev/pkg/oidcclient/state" "go.pinniped.dev/pkg/oidcclient/state"
) )
@ -31,19 +33,19 @@ import (
// mockSessionCache exists to avoid an import cycle if we generate mocks into another package. // mockSessionCache exists to avoid an import cycle if we generate mocks into another package.
type mockSessionCache struct { type mockSessionCache struct {
t *testing.T t *testing.T
getReturnsToken *Token getReturnsToken *oidctypes.Token
sawGetKeys []SessionCacheKey sawGetKeys []SessionCacheKey
sawPutKeys []SessionCacheKey sawPutKeys []SessionCacheKey
sawPutTokens []*Token sawPutTokens []*oidctypes.Token
} }
func (m *mockSessionCache) GetToken(key SessionCacheKey) *Token { func (m *mockSessionCache) GetToken(key SessionCacheKey) *oidctypes.Token {
m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID) m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID)
m.sawGetKeys = append(m.sawGetKeys, key) m.sawGetKeys = append(m.sawGetKeys, key)
return m.getReturnsToken return m.getReturnsToken
} }
func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) { func (m *mockSessionCache) PutToken(key SessionCacheKey, token *oidctypes.Token) {
m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token) m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token)
m.sawPutKeys = append(m.sawPutKeys, key) m.sawPutKeys = append(m.sawPutKeys, key)
m.sawPutTokens = append(m.sawPutTokens, token) m.sawPutTokens = append(m.sawPutTokens, token)
@ -54,20 +56,10 @@ func TestLogin(t *testing.T) {
time1Unix := int64(2075807775) time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := Token{ testToken := oidctypes.Token{
AccessToken: &AccessToken{ AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
Token: "test-access-token", RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)), IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
},
RefreshToken: &RefreshToken{
Token: "test-refresh-token",
},
IDToken: &IDToken{
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
},
} }
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
@ -76,7 +68,7 @@ func TestLogin(t *testing.T) {
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset and answers refresh requests. // Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close) t.Cleanup(successServer.Close)
@ -144,7 +136,7 @@ func TestLogin(t *testing.T) {
issuer string issuer string
clientID string clientID string
wantErr string wantErr string
wantToken *Token wantToken *oidctypes.Token
}{ }{
{ {
name: "option error", name: "option error",
@ -191,8 +183,8 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &IDToken{ IDToken: &oidctypes.IDToken{
Token: "test-id-token", Token: "test-id-token",
Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity
}, },
@ -246,12 +238,20 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{ h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
IDToken: &IDToken{ mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(testToken, nil, nil)
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
}, },
RefreshToken: &RefreshToken{Token: "test-refresh-token"}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}} }}
t.Cleanup(func() { t.Cleanup(func() {
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -266,12 +266,6 @@ func TestLogin(t *testing.T) {
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
@ -283,12 +277,20 @@ func TestLogin(t *testing.T) {
clientID: "test-client-id", clientID: "test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{ h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
IDToken: &IDToken{ mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
return mock
}
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
}, },
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
}} }}
t.Cleanup(func() { t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys) require.Empty(t, cache.sawPutKeys)
@ -296,16 +298,10 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil return nil
} }
}, },
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", wantErr: "some validation error",
}, },
{ {
name: "session cache hit but refresh fails", name: "session cache hit but refresh fails",
@ -313,12 +309,12 @@ func TestLogin(t *testing.T) {
clientID: "not-the-test-client-id", clientID: "not-the-test-client-id",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{ cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &IDToken{ IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token", Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
}, },
RefreshToken: &RefreshToken{Token: "test-refresh-token"}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}} }}
t.Cleanup(func() { t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys) require.Empty(t, cache.sawPutKeys)
@ -326,12 +322,6 @@ func TestLogin(t *testing.T) {
}) })
h.cache = cache h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
h.listenAddr = "invalid-listen-address" h.listenAddr = "invalid-listen-address"
return nil return nil
@ -413,7 +403,7 @@ func TestLogin(t *testing.T) {
t.Cleanup(func() { t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*Token{&testToken}, cache.sawPutTokens) require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
}) })
require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h)) require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
@ -481,7 +471,7 @@ func TestLogin(t *testing.T) {
require.NotNil(t, tok.AccessToken) require.NotNil(t, tok.AccessToken)
require.Equal(t, want.Token, tok.AccessToken.Token) require.Equal(t, want.Token, tok.AccessToken.Token)
require.Equal(t, want.Type, tok.AccessToken.Type) require.Equal(t, want.Type, tok.AccessToken.Type)
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
} else { } else {
assert.Nil(t, tok.AccessToken) assert.Nil(t, tok.AccessToken)
} }
@ -489,7 +479,7 @@ func TestLogin(t *testing.T) {
if want := tt.wantToken.IDToken; want != nil { if want := tt.wantToken.IDToken; want != nil {
require.NotNil(t, tok.IDToken) require.NotNil(t, tok.IDToken)
require.Equal(t, want.Token, tok.IDToken.Token) require.Equal(t, want.Token, tok.IDToken.Token)
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
} else { } else {
assert.Nil(t, tok.IDToken) assert.Nil(t, tok.IDToken)
} }
@ -498,11 +488,13 @@ func TestLogin(t *testing.T) {
} }
func TestHandleAuthCodeCallback(t *testing.T) { func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
tests := []struct { tests := []struct {
name string name string
method string method string
query string query string
returnIDTok string opt func(t *testing.T) Option
wantErr string wantErr string
wantHTTPStatus int wantHTTPStatus int
}{ }{
@ -528,94 +520,51 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{ {
name: "invalid code", name: "invalid code",
query: "state=test-state&code=invalid", query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n", wantErr: "could not complete code exchange: some exchange error",
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
}, },
{
name: "missing ID token",
query: "state=test-state&code=valid",
returnIDTok: "",
wantErr: "received response missing ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid ID token",
query: "state=test-state&code=valid",
returnIDTok: "invalid-jwt",
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid access token hash",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
wantErr: "received invalid ID token: access token hash does not match value in ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid nonce",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
wantHTTPStatus: http.StatusBadRequest,
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
}, },
{ {
name: "valid", name: "valid",
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
opt: func(t *testing.T) Option {
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/: return func(h *handlerState) error {
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q", h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
return mock
}
return nil
}
},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
require.NotEmpty(t, r.Form.Get("code"))
if r.Form.Get("code") != "valid" {
http.Error(w, "invalid authorization code", http.StatusForbidden)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
}
response.AccessToken = "test-access-token"
response.Expiry = time.Now().Add(time.Hour)
response.IDToken = tt.returnIDTok
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
oauth2Config: &oauth2.Config{ }
ClientID: "test-client-id", if tt.opt != nil {
RedirectURL: "http://localhost:12345/callback", require.NoError(t, tt.opt(t)(h))
Endpoint: oauth2.Endpoint{
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
},
idTokenVerifier: mockVerifier(),
} }
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@ -651,47 +600,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
require.NoError(t, result.err) require.NoError(t, result.err)
require.NotNil(t, result.token) require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok) require.Equal(t, result.token.IDToken.Token, "test-id-token")
} }
}) })
} }
} }
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else. func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
func mockVerifier() *oidc.IDTokenVerifier { t.Helper()
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil)) ctrl := gomock.NewController(t)
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()). t.Cleanup(ctrl.Finish)
AnyTimes(). return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
return jws.UnsafePayloadWithoutVerification(), nil
})
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
SkipIssuerCheck: true,
SkipExpiryCheck: true,
SkipClientIDCheck: true,
})
} }
type mockDiscovery struct{ provider *oidc.Provider } // hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
type hasAccessTokenMatcher struct{ expected string }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
return arg.(*oauth2.Token).AccessToken == m.expected
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } }
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { func (m hasAccessTokenMatcher) Got(got interface{}) string {
require.InDeltaf(t, return got.(*oauth2.Token).AccessToken
float64(t1.UnixNano()), }
float64(t2.UnixNano()),
float64(delta.Nanoseconds()), func (m hasAccessTokenMatcher) String() string {
"expected %s and %s to be < %s apart, but they are %s apart", return m.expected
t1.Format(time.RFC3339Nano), }
t2.Format(time.RFC3339Nano),
delta.String(), func HasAccessToken(expected string) gomock.Matcher {
t1.Sub(t2).String(), return hasAccessTokenMatcher{expected: expected}
)
} }

View File

@ -1,11 +1,10 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package oidcclient // Package oidctypes provides core data types for OIDC token structures.
package oidctypes
import ( import v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// AccessToken is an OAuth2 access token. // AccessToken is an OAuth2 access token.
type AccessToken struct { type AccessToken struct {
@ -16,7 +15,7 @@ type AccessToken struct {
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
// Expiry is the optional expiration time of the access token. // Expiry is the optional expiration time of the access token.
Expiry metav1.Time `json:"expiryTimestamp,omitempty"` Expiry v1.Time `json:"expiryTimestamp,omitempty"`
} }
// RefreshToken is an OAuth2 refresh token. // RefreshToken is an OAuth2 refresh token.
@ -31,7 +30,7 @@ type IDToken struct {
Token string `json:"token"` Token string `json:"token"`
// Expiry is the optional expiration time of the ID token. // Expiry is the optional expiration time of the ID token.
Expiry metav1.Time `json:"expiryTimestamp,omitempty"` Expiry v1.Time `json:"expiryTimestamp,omitempty"`
} }
// Token contains the elements of an OIDC session. // Token contains the elements of an OIDC session.
@ -47,16 +46,3 @@ type Token struct {
// IDToken is an OpenID Connect ID token. // IDToken is an OpenID Connect ID token.
IDToken *IDToken `json:"id,omitempty"` IDToken *IDToken `json:"id,omitempty"`
} }
// SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct {
Issuer string `json:"issuer"`
ClientID string `json:"clientID"`
Scopes []string `json:"scopes"`
RedirectURI string `json:"redirect_uri"`
}
type SessionCache interface {
GetToken(SessionCacheKey) *Token
PutToken(SessionCacheKey, *Token)
}

View File

@ -28,8 +28,7 @@ staticClients:
name: 'Pinniped Supervisor' name: 'Pinniped Supervisor'
secret: pinniped-supervisor-secret secret: pinniped-supervisor-secret
redirectURIs: redirectURIs:
- #@ "http://127.0.0.1:" + str(data.values.ports.cli) + "/callback" - https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback
- #@ "http://[::1]:" + str(data.values.ports.cli) + "/callback"
enablePasswordDB: true enablePasswordDB: true
staticPasswords: staticPasswords:
- username: "pinny" - username: "pinny"

View File

@ -20,7 +20,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/sclevine/agouti"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -30,6 +29,7 @@ import (
"go.pinniped.dev/pkg/oidcclient" "go.pinniped.dev/pkg/oidcclient"
"go.pinniped.dev/pkg/oidcclient/filesession" "go.pinniped.dev/pkg/oidcclient/filesession"
"go.pinniped.dev/test/library" "go.pinniped.dev/test/library"
"go.pinniped.dev/test/library/browsertest"
) )
func TestCLIGetKubeconfig(t *testing.T) { func TestCLIGetKubeconfig(t *testing.T) {
@ -108,80 +108,14 @@ func runPinnipedCLIGetKubeconfig(t *testing.T, pinnipedExe, token, namespaceName
return string(output) return string(output)
} }
type loginProviderPatterns struct {
Name string
IssuerPattern *regexp.Regexp
LoginPagePattern *regexp.Regexp
UsernameSelector string
PasswordSelector string
LoginButtonSelector string
}
func getLoginProvider(t *testing.T) *loginProviderPatterns {
t.Helper()
issuer := library.IntegrationEnv(t).CLITestUpstream.Issuer
for _, p := range []loginProviderPatterns{
{
Name: "Okta",
IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
UsernameSelector: "input#okta-signin-username",
PasswordSelector: "input#okta-signin-password",
LoginButtonSelector: "input#okta-signin-submit",
},
{
Name: "Dex",
IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`),
LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`),
UsernameSelector: "input#login",
PasswordSelector: "input#password",
LoginButtonSelector: "button#submit-login",
},
} {
if p.IssuerPattern.MatchString(issuer) {
return &p
}
}
require.Failf(t, "could not find login provider for issuer %q", issuer)
return nil
}
func TestCLILoginOIDC(t *testing.T) { func TestCLILoginOIDC(t *testing.T) {
env := library.IntegrationEnv(t) env := library.IntegrationEnv(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
// Find the login CSS selectors for the test issuer, or fail fast.
loginProvider := getLoginProvider(t)
// Start the browser driver. // Start the browser driver.
t.Logf("opening browser driver") page := browsertest.Open(t)
caps := agouti.NewCapabilities()
if env.Proxy != "" {
t.Logf("configuring Chrome to use proxy %q", env.Proxy)
caps = caps.Proxy(agouti.ProxyConfig{
ProxyType: "manual",
HTTPProxy: env.Proxy,
SSLProxy: env.Proxy,
NoProxy: "127.0.0.1",
})
}
agoutiDriver := agouti.ChromeDriver(
agouti.Desired(caps),
agouti.ChromeOptions("args", []string{
"--no-sandbox",
"--ignore-certificate-errors",
"--headless", // Comment out this line to see the tests happen in a visible browser window.
}),
// Uncomment this to see stdout/stderr from chromedriver.
// agouti.Debug,
)
require.NoError(t, agoutiDriver.Start())
t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) })
page, err := agoutiDriver.NewPage(agouti.Browser("chrome"))
require.NoError(t, err)
require.NoError(t, page.Reset())
// Build pinniped CLI. // Build pinniped CLI.
t.Logf("building CLI binary") t.Logf("building CLI binary")
@ -262,28 +196,18 @@ func TestCLILoginOIDC(t *testing.T) {
t.Logf("navigating to login page") t.Logf("navigating to login page")
require.NoError(t, page.Navigate(loginURL)) require.NoError(t, page.Navigate(loginURL))
// Expect to be redirected to the login page. // Expect to be redirected to the upstream provider and log in.
t.Logf("waiting for redirect to %s login page", loginProvider.Name) browsertest.LoginToUpstream(t, page, env.CLITestUpstream)
waitForURL(t, page, loginProvider.LoginPagePattern)
// Wait for the login page to be rendered. // Expect to be redirected to the localhost callback.
waitForVisibleElements(t, page, loginProvider.UsernameSelector, loginProvider.PasswordSelector, loginProvider.LoginButtonSelector) t.Logf("waiting for redirect to callback")
// Fill in the username and password and click "submit".
t.Logf("logging into %s", loginProvider.Name)
require.NoError(t, page.First(loginProvider.UsernameSelector).Fill(env.CLITestUpstream.Username))
require.NoError(t, page.First(loginProvider.PasswordSelector).Fill(env.CLITestUpstream.Password))
require.NoError(t, page.First(loginProvider.LoginButtonSelector).Click())
// Wait for the login to happen and us be redirected back to a localhost callback.
t.Logf("waiting for redirect to localhost callback")
callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(env.CLITestUpstream.CallbackURL) + `\?.+\z`) callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(env.CLITestUpstream.CallbackURL) + `\?.+\z`)
waitForURL(t, page, callbackURLPattern) browsertest.WaitForURL(t, page, callbackURLPattern)
// Wait for the "pre" element that gets rendered for a `text/plain` page, and // Wait for the "pre" element that gets rendered for a `text/plain` page, and
// assert that it contains the success message. // assert that it contains the success message.
t.Logf("verifying success page") t.Logf("verifying success page")
waitForVisibleElements(t, page, "pre") browsertest.WaitForVisibleElements(t, page, "pre")
msg, err := page.First("pre").Text() msg, err := page.First("pre").Text()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "you have been logged in and may now close this tab", msg) require.Equal(t, "you have been logged in and may now close this tab", msg)
@ -361,44 +285,6 @@ func TestCLILoginOIDC(t *testing.T) {
require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token) require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token)
} }
func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {
t.Helper()
require.Eventually(t,
func() bool {
for _, sel := range selectors {
vis, err := page.First(sel).Visible()
if !(err == nil && vis) {
return false
}
}
return true
},
10*time.Second,
100*time.Millisecond,
)
}
func waitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) {
var lastURL string
require.Eventuallyf(t,
func() bool {
url, err := page.URL()
if err == nil && pat.MatchString(url) {
return true
}
if url != lastURL {
t.Logf("saw URL %s", url)
lastURL = url
}
return false
},
10*time.Second,
100*time.Millisecond,
"expected to browse to %s, but never got there",
pat,
)
}
func readAndExpectEmpty(r io.Reader) (err error) { func readAndExpectEmpty(r io.Reader) (err error) {
var remainder bytes.Buffer var remainder bytes.Buffer
_, err = io.Copy(&remainder, r) _, err = io.Copy(&remainder, r)

View File

@ -17,7 +17,7 @@ import (
"k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/fosite/authorizationcode" "go.pinniped.dev/internal/fositestorage/authorizationcode"
"go.pinniped.dev/test/library" "go.pinniped.dev/test/library"
) )
@ -29,7 +29,7 @@ func TestAuthorizeCodeStorage(t *testing.T) {
// randomly generated HMAC authorization code (see below) // randomly generated HMAC authorization code (see below)
code = "TQ72B8YjdEOZyxridYbTLE-pzoK4hpdkZxym5j4EmSc.TKRTgQG41IBQ16FDKTthRdhXfLlNaErcMd9Fy47uXAw" code = "TQ72B8YjdEOZyxridYbTLE-pzoK4hpdkZxym5j4EmSc.TKRTgQG41IBQ16FDKTthRdhXfLlNaErcMd9Fy47uXAw"
// name of the secret that will be created in Kube // name of the secret that will be created in Kube
name = "pinniped-storage-authorization-codes-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga" name = "pinniped-storage-authcode-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga"
) )
hmac := compose.NewOAuth2HMACStrategy(&compose.Config{}, []byte("super-secret-32-byte-for-testing"), nil) hmac := compose.NewOAuth2HMACStrategy(&compose.Config{}, []byte("super-secret-32-byte-for-testing"), nil)

View File

@ -111,7 +111,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) {
// When the same issuer is added twice, both issuers are marked as duplicates, and neither provider is serving. // When the same issuer is added twice, both issuers are marked as duplicates, and neither provider is serving.
config6Duplicate1, _ := requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(ctx, t, scheme, addr, caBundle, issuer6, client) config6Duplicate1, _ := requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(ctx, t, scheme, addr, caBundle, issuer6, client)
config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "") config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "", "")
requireStatus(t, client, ns, config6Duplicate1.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireStatus(t, client, ns, config6Duplicate1.Name, v1alpha1.DuplicateOIDCProviderStatusCondition)
requireStatus(t, client, ns, config6Duplicate2.Name, v1alpha1.DuplicateOIDCProviderStatusCondition) requireStatus(t, client, ns, config6Duplicate2.Name, v1alpha1.DuplicateOIDCProviderStatusCondition)
requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, issuer6) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, issuer6)
@ -136,7 +136,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) {
} }
// When we create a provider with an invalid issuer, the status is set to invalid. // When we create a provider with an invalid issuer, the status is set to invalid.
badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "") badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "", "")
requireStatus(t, client, ns, badConfig.Name, v1alpha1.InvalidOIDCProviderStatusCondition) requireStatus(t, client, ns, badConfig.Name, v1alpha1.InvalidOIDCProviderStatusCondition)
requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, badIssuer) requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, badIssuer)
requireDeletingOIDCProviderCausesDiscoveryEndpointsToDisappear(t, badConfig, client, ns, scheme, addr, caBundle, badIssuer) requireDeletingOIDCProviderCausesDiscoveryEndpointsToDisappear(t, badConfig, client, ns, scheme, addr, caBundle, badIssuer)
@ -162,7 +162,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) {
certSecretName1 := "integration-test-cert-1" certSecretName1 := "integration-test-cert-1"
// Create an OIDCProvider with a spec.tls.secretName. // Create an OIDCProvider with a spec.tls.secretName.
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1) oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1, "")
requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
// The spec.tls.secretName Secret does not exist, so the endpoints should fail with TLS errors. // The spec.tls.secretName Secret does not exist, so the endpoints should fail with TLS errors.
@ -198,7 +198,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) {
certSecretName2 := "integration-test-cert-2" certSecretName2 := "integration-test-cert-2"
// Create an OIDCProvider with a spec.tls.secretName. // Create an OIDCProvider with a spec.tls.secretName.
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2) oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2, "")
requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
// Create the Secret. // Create the Secret.
@ -232,31 +232,30 @@ func TestSupervisorTLSTerminationWithDefaultCerts(t *testing.T) {
port = hostAndPortSegments[1] port = hostAndPortSegments[1]
} }
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname) ips, err := library.LookupIP(ctx, hostname)
require.NoError(t, err) require.NoError(t, err)
ip := ips[0] require.NotEmpty(t, ips)
ipAsString := ip.String() ipWithPort := ips[0].String() + ":" + port
ipWithPort := ipAsString + ":" + port
issuerUsingIPAddress := fmt.Sprintf("%s://%s/issuer1", scheme, ipWithPort) issuerUsingIPAddress := fmt.Sprintf("%s://%s/issuer1", scheme, ipWithPort)
issuerUsingHostname := fmt.Sprintf("%s://%s/issuer1", scheme, address) issuerUsingHostname := fmt.Sprintf("%s://%s/issuer1", scheme, address)
// Create an OIDCProvider without a spec.tls.secretName. // Create an OIDCProvider without a spec.tls.secretName.
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "") oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "", "")
requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition) requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
// There is no default TLS cert and the spec.tls.secretName was not set, so the endpoints should fail with TLS errors. // There is no default TLS cert and the spec.tls.secretName was not set, so the endpoints should fail with TLS errors.
requireEndpointHasTLSErrorBecauseCertificatesAreNotReady(t, issuerUsingIPAddress) requireEndpointHasTLSErrorBecauseCertificatesAreNotReady(t, issuerUsingIPAddress)
// Create a Secret at the special name which represents the default TLS cert. // Create a Secret at the special name which represents the default TLS cert.
defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ip.IP}, defaultTLSCertSecretName(env), kubeClient) defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ips[0]}, defaultTLSCertSecretName(env), kubeClient)
// Now that the Secret exists, we should be able to access the endpoints by IP address using the CA. // Now that the Secret exists, we should be able to access the endpoints by IP address using the CA.
_ = requireDiscoveryEndpointsAreWorking(t, scheme, ipWithPort, string(defaultCA.Bundle()), issuerUsingIPAddress, nil) _ = requireDiscoveryEndpointsAreWorking(t, scheme, ipWithPort, string(defaultCA.Bundle()), issuerUsingIPAddress, nil)
// Create an OIDCProvider with a spec.tls.secretName. // Create an OIDCProvider with a spec.tls.secretName.
certSecretName := "integration-test-cert-1" certSecretName := "integration-test-cert-1"
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName) oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName, "")
requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition) requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
// Create the Secret. // Create the Secret.
@ -429,7 +428,7 @@ func requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(
client pinnipedclientset.Interface, client pinnipedclientset.Interface,
) (*v1alpha1.OIDCProvider, *ExpectedJWKSResponseFormat) { ) (*v1alpha1.OIDCProvider, *ExpectedJWKSResponseFormat) {
t.Helper() t.Helper()
newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "") newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "", "")
jwksResult := requireDiscoveryEndpointsAreWorking(t, supervisorScheme, supervisorAddress, supervisorCABundle, issuerName, nil) jwksResult := requireDiscoveryEndpointsAreWorking(t, supervisorScheme, supervisorAddress, supervisorCABundle, issuerName, nil)
requireStatus(t, client, newOIDCProvider.Namespace, newOIDCProvider.Name, v1alpha1.SuccessOIDCProviderStatusCondition) requireStatus(t, client, newOIDCProvider.Namespace, newOIDCProvider.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
return newOIDCProvider, jwksResult return newOIDCProvider, jwksResult

View File

@ -27,7 +27,7 @@ func TestSupervisorOIDCKeys(t *testing.T) {
defer cancel() defer cancel()
// Create our OPC under test. // Create our OPC under test.
opc := library.CreateTestOIDCProvider(ctx, t, "", "") opc := library.CreateTestOIDCProvider(ctx, t, "", "", "")
// Ensure a secret is created with the OPC's JWKS. // Ensure a secret is created with the OPC's JWKS.
var updatedOPC *configv1alpha1.OIDCProvider var updatedOPC *configv1alpha1.OIDCProvider

View File

@ -6,236 +6,180 @@ package integration
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509/pkix"
"encoding/base64" "encoding/base64"
"fmt"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"path" "regexp"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1"
idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/certauthority"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
"go.pinniped.dev/pkg/oidcclient/state" "go.pinniped.dev/pkg/oidcclient/state"
"go.pinniped.dev/test/library" "go.pinniped.dev/test/library"
"go.pinniped.dev/test/library/browsertest"
) )
func TestSupervisorLogin(t *testing.T) { func TestSupervisorLogin(t *testing.T) {
t.Skip("waiting on new callback path logic to get merged in from the callback endpoint work")
env := library.IntegrationEnv(t) env := library.IntegrationEnv(t)
client := library.NewSupervisorClientset(t)
// If anything in this test crashes, dump out the supervisor pod logs.
defer library.DumpLogs(t, env.SupervisorNamespace)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() defer cancel()
tests := []struct { // Infer the downstream issuer URL from the callback associated with the upstream test client registration.
Scheme string issuerURL, err := url.Parse(env.SupervisorTestUpstream.CallbackURL)
Address string require.NoError(t, err)
CABundle string require.True(t, strings.HasSuffix(issuerURL.Path, "/callback"))
}{ issuerURL.Path = strings.TrimSuffix(issuerURL.Path, "/callback")
{Scheme: "http", Address: env.SupervisorHTTPAddress}, t.Logf("testing with downstream issuer URL %s", issuerURL.String())
{Scheme: "https", Address: env.SupervisorHTTPSIngressAddress, CABundle: env.SupervisorHTTPSIngressCABundle},
// Generate a CA bundle with which to serve this provider.
t.Logf("generating test CA")
ca, err := certauthority.New(pkix.Name{CommonName: "Downstream Test CA"}, 1*time.Hour)
require.NoError(t, err)
// Create an HTTP client that can reach the downstream discovery endpoint using the CA certs.
httpClient := &http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: ca.Pool()},
Proxy: func(req *http.Request) (*url.URL, error) {
if env.Proxy == "" {
return nil, nil
} }
return url.Parse(env.Proxy)
},
}}
for _, test := range tests { // Use the CA to issue a TLS server cert.
scheme := test.Scheme t.Logf("issuing test certificate")
addr := test.Address tlsCert, err := ca.Issue(
caBundle := test.CABundle pkix.Name{CommonName: issuerURL.Hostname()},
[]string{issuerURL.Hostname()},
nil,
1*time.Hour,
)
require.NoError(t, err)
certPEM, keyPEM, err := certauthority.ToPEM(tlsCert)
require.NoError(t, err)
if addr == "" { // Write the serving cert to a secret.
// Both cases are not required, so when one is empty skip it. certSecret := library.CreateTestSecret(t,
continue env.SupervisorNamespace,
} "oidc-provider-tls",
"kubernetes.io/tls",
// Create downstream OIDC provider (i.e., update supervisor with OIDC provider). map[string]string{"tls.crt": string(certPEM), "tls.key": string(keyPEM)},
path := getDownstreamIssuerPathFromUpstreamRedirectURI(t, env.SupervisorTestUpstream.CallbackURL)
issuer := fmt.Sprintf("https://%s%s", addr, path)
_, _ = requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(
ctx,
t,
scheme,
addr,
caBundle,
issuer,
client,
) )
// Create HTTP client. // Create the downstream OIDCProvider and expect it to go into the success status condition.
httpClient := newHTTPClient(t, caBundle, nil) downstream := library.CreateTestOIDCProvider(ctx, t,
httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error { issuerURL.String(),
// Don't follow any redirects right now, since we simply want to validate that our auth endpoint certSecret.Name,
// redirects us. configv1alpha1.SuccessOIDCProviderStatusCondition,
return http.ErrUseLastResponse )
}
// Declare the downstream auth endpoint url we will use. // Create upstream OIDC provider and wait for it to become ready.
downstreamAuthURL := makeDownstreamAuthURL(t, scheme, addr, path) library.CreateTestUpstreamOIDCProvider(t, idpv1alpha1.UpstreamOIDCProviderSpec{
// Make request to auth endpoint - should fail, since we have no upstreams.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil)
require.NoError(t, err)
rsp, err := httpClient.Do(req)
require.NoError(t, err)
defer rsp.Body.Close()
require.Equal(t, http.StatusUnprocessableEntity, rsp.StatusCode)
// Create upstream OIDC provider.
spec := idpv1alpha1.UpstreamOIDCProviderSpec{
Issuer: env.SupervisorTestUpstream.Issuer, Issuer: env.SupervisorTestUpstream.Issuer,
TLS: &idpv1alpha1.TLSSpec{ TLS: &idpv1alpha1.TLSSpec{
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)), CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)),
}, },
Client: idpv1alpha1.OIDCClient{ Client: idpv1alpha1.OIDCClient{
SecretName: makeTestClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name, SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name,
}, },
} }, idpv1alpha1.PhaseReady)
upstream := makeTestUpstream(t, spec, idpv1alpha1.PhaseReady)
// Make request to authorize endpoint - should pass, since we now have an upstream. // Perform OIDC discovery for our downstream.
req, err = http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil) var discovery *oidc.Provider
require.NoError(t, err) assert.Eventually(t, func() bool {
rsp, err = httpClient.Do(req) discovery, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer)
require.NoError(t, err) return err == nil
defer rsp.Body.Close() }, 60*time.Second, 1*time.Second)
require.Equal(t, http.StatusFound, rsp.StatusCode)
requireValidRedirectLocation(
ctx,
t,
upstream.Spec.Issuer,
env.SupervisorTestUpstream.ClientID,
env.SupervisorTestUpstream.CallbackURL,
rsp.Header.Get("Location"),
)
}
}
//nolint:unused
func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedirectURI string) string {
// We need to construct the downstream issuer path from the upstream redirect URI since the two
// are related, and the upstream redirect URI is supplied via a static test environment
// variable. The upstream redirect URI should be something like
// https://supervisor.com/some/supervisor/path/callback
// and therefore the downstream issuer should be something like
// https://supervisor.com/some/supervisor/path
// since the /callback endpoint is placed at the root of the downstream issuer path.
upstreamRedirectURL, err := url.Parse(upstreamRedirectURI)
require.NoError(t, err) require.NoError(t, err)
redirectURIPathWithoutLastSegment, lastUpstreamRedirectURIPathSegment := path.Split(upstreamRedirectURL.Path) // Start a callback server on localhost.
require.Equalf( localCallbackServer := startLocalCallbackServer(t)
t,
"callback",
lastUpstreamRedirectURIPathSegment,
"expected upstream redirect URI (%q) to follow supervisor callback path conventions (i.e., end in /callback)",
upstreamRedirectURI,
)
if strings.HasSuffix(redirectURIPathWithoutLastSegment, "/") { // Form the OAuth2 configuration corresponding to our CLI client.
redirectURIPathWithoutLastSegment = redirectURIPathWithoutLastSegment[:len(redirectURIPathWithoutLastSegment)-1]
}
return redirectURIPathWithoutLastSegment
}
//nolint:unused
func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string {
t.Helper()
downstreamOAuth2Config := oauth2.Config{ downstreamOAuth2Config := oauth2.Config{
// This is the hardcoded public client that the supervisor supports. // This is the hardcoded public client that the supervisor supports.
ClientID: "pinniped-cli", ClientID: "pinniped-cli",
Endpoint: oauth2.Endpoint{ Endpoint: discovery.Endpoint(),
AuthURL: fmt.Sprintf("%s://%s%s/oauth2/authorize", scheme, addr, path), RedirectURL: localCallbackServer.URL,
},
// This is the hardcoded downstream redirect URI that the supervisor supports.
RedirectURL: "http://127.0.0.1/callback",
Scopes: []string{"openid"}, Scopes: []string{"openid"},
} }
state, nonce, pkce := generateAuthRequestParams(t)
return downstreamOAuth2Config.AuthCodeURL( // Build a valid downstream authorize URL for the supervisor.
state.String(), stateParam, err := state.Generate()
nonce.Param(), require.NoError(t, err)
pkce.Challenge(), nonceParam, err := nonce.Generate()
pkce.Method(), require.NoError(t, err)
pkceParam, err := pkce.Generate()
require.NoError(t, err)
downstreamAuthorizeURL := downstreamOAuth2Config.AuthCodeURL(
stateParam.String(),
nonceParam.Param(),
pkceParam.Challenge(),
pkceParam.Method(),
) )
// Open the web browser and navigate to the downstream authorize URL.
page := browsertest.Open(t)
t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL))
require.NoError(t, page.Navigate(downstreamAuthorizeURL))
// Expect to be redirected to the upstream provider and log in.
browsertest.LoginToUpstream(t, page, env.SupervisorTestUpstream)
// Wait for the login to happen and us be redirected back to a localhost callback.
t.Logf("waiting for redirect to callback")
callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(localCallbackServer.URL) + `\?.+\z`)
browsertest.WaitForURL(t, page, callbackURLPattern)
// Expect that our callback handler was invoked.
callback := localCallbackServer.waitForCallback(10 * time.Second)
t.Logf("got callback request: %s", library.MaskTokens(callback.URL.String()))
require.Equal(t, stateParam.String(), callback.URL.Query().Get("state"))
require.Equal(t, "openid", callback.URL.Query().Get("scope"))
require.NotEmpty(t, callback.URL.Query().Get("code"))
} }
//nolint:unused func startLocalCallbackServer(t *testing.T) *localCallbackServer {
func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Code) { // Handle the callback by sending the *http.Request object back through a channel.
t.Helper() callbacks := make(chan *http.Request, 1)
state, err := state.Generate() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, err) callbacks <- r
nonce, err := nonce.Generate() }))
require.NoError(t, err) server.URL += "/callback"
pkce, err := pkce.Generate() t.Cleanup(server.Close)
require.NoError(t, err) t.Cleanup(func() { close(callbacks) })
return state, nonce, pkce return &localCallbackServer{Server: server, t: t, callbacks: callbacks}
} }
//nolint:unused type localCallbackServer struct {
func requireValidRedirectLocation( *httptest.Server
ctx context.Context, t *testing.T
t *testing.T, callbacks <-chan *http.Request
issuer, clientID, redirectURI, actualLocation string,
) {
t.Helper()
env := library.IntegrationEnv(t)
// Do OIDC discovery on our test issuer to get auth endpoint.
transport := http.Transport{}
if env.Proxy != "" {
transport.Proxy = func(_ *http.Request) (*url.URL, error) {
return url.Parse(env.Proxy)
}
}
if env.SupervisorTestUpstream.CABundle != "" {
transport.TLSClientConfig = &tls.Config{RootCAs: x509.NewCertPool()}
transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(env.SupervisorTestUpstream.CABundle))
} }
ctx = oidc.ClientContext(ctx, &http.Client{Transport: &transport}) func (s *localCallbackServer) waitForCallback(timeout time.Duration) *http.Request {
upstreamProvider, err := oidc.NewProvider(ctx, issuer) select {
require.NoError(t, err) case callback := <-s.callbacks:
return callback
// Parse expected upstream auth URL. case <-time.After(timeout):
expectedLocationURL, err := url.Parse( require.Fail(s.t, "timed out waiting for callback request")
(&oauth2.Config{ return nil
ClientID: clientID, }
Endpoint: upstreamProvider.Endpoint(),
RedirectURL: redirectURI,
Scopes: []string{"openid"},
}).AuthCodeURL("", oauth2.AccessTypeOffline),
)
require.NoError(t, err)
// Parse actual upstream auth URL.
actualLocationURL, err := url.Parse(actualLocation)
require.NoError(t, err)
// First make some assertions on the query values. Note that we will not be able to know what
// certain query values are since they may be random (e.g., state, pkce, nonce).
expectedLocationQuery := expectedLocationURL.Query()
actualLocationQuery := actualLocationURL.Query()
require.NotEmpty(t, actualLocationQuery.Get("state"))
actualLocationQuery.Del("state")
require.NotEmpty(t, actualLocationQuery.Get("code_challenge"))
actualLocationQuery.Del("code_challenge")
require.NotEmpty(t, actualLocationQuery.Get("code_challenge_method"))
actualLocationQuery.Del("code_challenge_method")
require.NotEmpty(t, actualLocationQuery.Get("nonce"))
actualLocationQuery.Del("nonce")
require.Equal(t, expectedLocationQuery, actualLocationQuery)
// Zero-out query values, since we made specific assertions about those above, and assert that the
// URL's are equal otherwise.
expectedLocationURL.RawQuery = ""
actualLocationURL.RawQuery = ""
require.Equal(t, expectedLocationURL, actualLocationURL)
} }

View File

@ -4,13 +4,10 @@
package integration package integration
import ( import (
"context"
"encoding/base64" "encoding/base64"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
@ -28,7 +25,7 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) {
SecretName: "does-not-exist", SecretName: "does-not-exist",
}, },
} }
upstream := makeTestUpstream(t, spec, v1alpha1.PhaseError) upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseError)
expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ expectUpstreamConditions(t, upstream, []v1alpha1.Condition{
{ {
Type: "ClientCredentialsValid", Type: "ClientCredentialsValid",
@ -56,10 +53,10 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) {
AdditionalScopes: []string{"email", "profile"}, AdditionalScopes: []string{"email", "profile"},
}, },
Client: v1alpha1.OIDCClient{ Client: v1alpha1.OIDCClient{
SecretName: makeTestClientCredsSecret(t, "test-client-id", "test-client-secret").Name, SecretName: library.CreateClientCredsSecret(t, "test-client-id", "test-client-secret").Name,
}, },
} }
upstream := makeTestUpstream(t, spec, v1alpha1.PhaseReady) upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseReady)
expectUpstreamConditions(t, upstream, []v1alpha1.Condition{ expectUpstreamConditions(t, upstream, []v1alpha1.Condition{
{ {
Type: "ClientCredentialsValid", Type: "ClientCredentialsValid",
@ -87,74 +84,3 @@ func expectUpstreamConditions(t *testing.T, upstream *v1alpha1.UpstreamOIDCProvi
} }
require.ElementsMatch(t, expected, normalized) require.ElementsMatch(t, expected, normalized)
} }
func makeTestClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret {
t.Helper()
env := library.IntegrationEnv(t)
client := library.NewClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
created, err := client.CoreV1().Secrets(env.SupervisorNamespace).Create(ctx, &corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Namespace: env.SupervisorNamespace,
GenerateName: "test-client-creds-",
Labels: map[string]string{"pinniped.dev/test": ""},
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
},
Type: "secrets.pinniped.dev/oidc-client",
StringData: map[string]string{
"clientID": clientID,
"clientSecret": clientSecret,
},
}, metav1.CreateOptions{})
require.NoError(t, err)
t.Cleanup(func() {
err := client.CoreV1().Secrets(env.SupervisorNamespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{})
require.NoError(t, err)
})
t.Logf("created test client credentials Secret %s", created.Name)
return created
}
func makeTestUpstream(t *testing.T, spec v1alpha1.UpstreamOIDCProviderSpec, expectedPhase v1alpha1.UpstreamOIDCProviderPhase) *v1alpha1.UpstreamOIDCProvider {
t.Helper()
env := library.IntegrationEnv(t)
client := library.NewSupervisorClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Create the UpstreamOIDCProvider using GenerateName to get a random name.
created, err := client.IDPV1alpha1().
UpstreamOIDCProviders(env.SupervisorNamespace).
Create(ctx, &v1alpha1.UpstreamOIDCProvider{
ObjectMeta: metav1.ObjectMeta{
Namespace: env.SupervisorNamespace,
GenerateName: "test-upstream-",
Labels: map[string]string{"pinniped.dev/test": ""},
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
},
Spec: spec,
}, metav1.CreateOptions{})
require.NoError(t, err)
// Always clean this up after this point.
t.Cleanup(func() {
err := client.IDPV1alpha1().
UpstreamOIDCProviders(env.SupervisorNamespace).
Delete(context.Background(), created.Name, metav1.DeleteOptions{})
require.NoError(t, err)
})
t.Logf("created test UpstreamOIDCProvider %s", created.Name)
// Wait for the UpstreamOIDCProvider to enter the expected phase (or time out).
var result *v1alpha1.UpstreamOIDCProvider
require.Eventuallyf(t, func() bool {
var err error
result, err = client.IDPV1alpha1().
UpstreamOIDCProviders(created.Namespace).Get(ctx, created.Name, metav1.GetOptions{})
require.NoError(t, err)
return result.Status.Phase == expectedPhase
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase)
return result
}

View File

@ -0,0 +1,158 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package browsertest provides integration test helpers for our browser-based tests.
package browsertest
import (
"regexp"
"testing"
"time"
"github.com/sclevine/agouti"
"github.com/stretchr/testify/require"
"go.pinniped.dev/test/library"
)
const (
operationTimeout = 10 * time.Second
operationPollingInterval = 100 * time.Millisecond
)
// Open a webdriver-driven browser and returns an *agouti.Page to control it. The browser will be automatically
// closed at the end of the current test. It is configured for test purposes with the correct HTTP proxy and
// in a mode that ignore certificate errors.
func Open(t *testing.T) *agouti.Page {
t.Logf("opening browser driver")
env := library.IntegrationEnv(t)
caps := agouti.NewCapabilities()
if env.Proxy != "" {
t.Logf("configuring Chrome to use proxy %q", env.Proxy)
caps = caps.Proxy(agouti.ProxyConfig{
ProxyType: "manual",
HTTPProxy: env.Proxy,
SSLProxy: env.Proxy,
NoProxy: "127.0.0.1",
})
}
agoutiDriver := agouti.ChromeDriver(
agouti.Desired(caps),
agouti.ChromeOptions("args", []string{
"--no-sandbox",
"--ignore-certificate-errors",
"--headless", // Comment out this line to see the tests happen in a visible browser window.
}),
// Uncomment this to see stdout/stderr from chromedriver.
// agouti.Debug,
)
require.NoError(t, agoutiDriver.Start())
t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) })
page, err := agoutiDriver.NewPage(agouti.Browser("chrome"))
require.NoError(t, err)
require.NoError(t, page.Reset())
return page
}
// WaitForVisibleElements expects the page to contain all the the elements specified by the selectors. It waits for this
// to occur and times out, failing the test, if they never appear.
func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {
t.Helper()
require.Eventuallyf(t,
func() bool {
for _, sel := range selectors {
vis, err := page.First(sel).Visible()
if !(err == nil && vis) {
return false
}
}
return true
},
operationTimeout,
operationPollingInterval,
"expected to have a page with selectors %v, but it never loaded",
selectors,
)
}
// WaitForURL expects the page to eventually navigate to a URL matching the specified pattern. It waits for this
// to occur and times out, failing the test, if it never does.
func WaitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) {
var lastURL string
require.Eventuallyf(t,
func() bool {
url, err := page.URL()
if err == nil && pat.MatchString(url) {
return true
}
if url != lastURL {
t.Logf("saw URL %s", url)
lastURL = url
}
return false
},
operationTimeout,
operationPollingInterval,
"expected to browse to %s, but never got there",
pat,
)
}
// LoginToUpstream expects the page to be redirected to one of several known upstream IDPs.
// It knows how to enter the test username/password and submit the upstream login form.
func LoginToUpstream(t *testing.T, page *agouti.Page, upstream library.TestOIDCUpstream) {
t.Helper()
type config struct {
Name string
IssuerPattern *regexp.Regexp
LoginPagePattern *regexp.Regexp
UsernameSelector string
PasswordSelector string
LoginButtonSelector string
}
// Lookup the provider by matching on the issuer URL.
var cfg *config
for _, p := range []*config{
{
Name: "Okta",
IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
UsernameSelector: "input#okta-signin-username",
PasswordSelector: "input#okta-signin-password",
LoginButtonSelector: "input#okta-signin-submit",
},
{
Name: "Dex",
IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`),
LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`),
UsernameSelector: "input#login",
PasswordSelector: "input#password",
LoginButtonSelector: "button#submit-login",
},
} {
if p.IssuerPattern.MatchString(upstream.Issuer) {
cfg = p
break
}
}
if cfg == nil {
require.Failf(t, "could not find login provider for issuer %q", upstream.Issuer)
return
}
// Expect to be redirected to the login page.
t.Logf("waiting for redirect to %s login page", cfg.Name)
WaitForURL(t, page, cfg.LoginPagePattern)
// Wait for the login page to be rendered.
WaitForVisibleElements(t, page, cfg.UsernameSelector, cfg.PasswordSelector, cfg.LoginButtonSelector)
// Fill in the username and password and click "submit".
t.Logf("logging into %s", cfg.Name)
require.NoError(t, page.First(cfg.UsernameSelector).Fill(upstream.Username))
require.NoError(t, page.First(cfg.PasswordSelector).Fill(upstream.Password))
require.NoError(t, page.First(cfg.LoginButtonSelector).Click())
}

View File

@ -25,6 +25,7 @@ import (
auth1alpha1 "go.pinniped.dev/generated/1.19/apis/concierge/authentication/v1alpha1" auth1alpha1 "go.pinniped.dev/generated/1.19/apis/concierge/authentication/v1alpha1"
configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1" configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1"
idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
conciergeclientset "go.pinniped.dev/generated/1.19/client/concierge/clientset/versioned" conciergeclientset "go.pinniped.dev/generated/1.19/client/concierge/clientset/versioned"
supervisorclientset "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned" supervisorclientset "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned"
@ -140,11 +141,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty
defer cancel() defer cancel()
webhook, err := webhooks.Create(createContext, &auth1alpha1.WebhookAuthenticator{ webhook, err := webhooks.Create(createContext, &auth1alpha1.WebhookAuthenticator{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: testObjectMeta(t, "webhook"),
GenerateName: "test-webhook-",
Labels: map[string]string{"pinniped.dev/test": ""},
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
},
Spec: testEnv.TestWebhook, Spec: testEnv.TestWebhook,
}, metav1.CreateOptions{}) }, metav1.CreateOptions{})
require.NoError(t, err, "could not create test WebhookAuthenticator") require.NoError(t, err, "could not create test WebhookAuthenticator")
@ -172,7 +169,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty
// //
// If the provided issuer is not the empty string, then it will be used for the // If the provided issuer is not the empty string, then it will be used for the
// OIDCProvider.Spec.Issuer field. Else, a random issuer will be generated. // OIDCProvider.Spec.Issuer field. Else, a random issuer will be generated.
func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecretName string) *configv1alpha1.OIDCProvider { func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer string, certSecretName string, expectStatus configv1alpha1.OIDCProviderStatusCondition) *configv1alpha1.OIDCProvider {
t.Helper() t.Helper()
testEnv := IntegrationEnv(t) testEnv := IntegrationEnv(t)
@ -180,18 +177,12 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre
defer cancel() defer cancel()
if issuer == "" { if issuer == "" {
var err error issuer = randomIssuer(t)
issuer, err = randomIssuer()
require.NoError(t, err)
} }
opcs := NewSupervisorClientset(t).ConfigV1alpha1().OIDCProviders(testEnv.SupervisorNamespace) opcs := NewSupervisorClientset(t).ConfigV1alpha1().OIDCProviders(testEnv.SupervisorNamespace)
opc, err := opcs.Create(createContext, &configv1alpha1.OIDCProvider{ opc, err := opcs.Create(createContext, &configv1alpha1.OIDCProvider{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: testObjectMeta(t, "oidc-provider"),
GenerateName: "test-oidc-provider-",
Labels: map[string]string{"pinniped.dev/test": ""},
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
},
Spec: configv1alpha1.OIDCProviderSpec{ Spec: configv1alpha1.OIDCProviderSpec{
Issuer: issuer, Issuer: issuer,
TLS: &configv1alpha1.OIDCProviderTLSSpec{SecretName: certSecretName}, TLS: &configv1alpha1.OIDCProviderTLSSpec{SecretName: certSecretName},
@ -213,13 +204,103 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre
} }
}) })
// If we're not expecting any particular status, just return the new OIDCProvider immediately.
if expectStatus == "" {
return opc return opc
} }
func randomIssuer() (string, error) { // Wait for the OIDCProvider to enter the expected phase (or time out).
var result *configv1alpha1.OIDCProvider
require.Eventuallyf(t, func() bool {
var err error
result, err = opcs.Get(ctx, opc.Name, metav1.GetOptions{})
require.NoError(t, err)
return result.Status.Status == expectStatus
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectStatus)
return opc
}
func randomIssuer(t *testing.T) string {
var buf [8]byte var buf [8]byte
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil { _, err := io.ReadFull(rand.Reader, buf[:])
return "", fmt.Errorf("could not generate random state: %w", err) require.NoError(t, err)
return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:]))
}
func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType string, stringData map[string]string) *corev1.Secret {
t.Helper()
client := NewClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
created, err := client.CoreV1().Secrets(namespace).Create(ctx, &corev1.Secret{
ObjectMeta: testObjectMeta(t, baseName),
Type: corev1.SecretType(secretType),
StringData: stringData,
}, metav1.CreateOptions{})
require.NoError(t, err)
t.Cleanup(func() {
err := client.CoreV1().Secrets(namespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{})
require.NoError(t, err)
})
t.Logf("created test Secret %s", created.Name)
return created
}
func CreateClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret {
t.Helper()
env := IntegrationEnv(t)
return CreateTestSecret(t,
env.SupervisorNamespace,
"test-client-creds-",
"secrets.pinniped.dev/oidc-client",
map[string]string{
"clientID": clientID,
"clientSecret": clientSecret,
},
)
}
func CreateTestUpstreamOIDCProvider(t *testing.T, spec idpv1alpha1.UpstreamOIDCProviderSpec, expectedPhase idpv1alpha1.UpstreamOIDCProviderPhase) *idpv1alpha1.UpstreamOIDCProvider {
t.Helper()
env := IntegrationEnv(t)
client := NewSupervisorClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Create the UpstreamOIDCProvider using GenerateName to get a random name.
upstreams := client.IDPV1alpha1().UpstreamOIDCProviders(env.SupervisorNamespace)
created, err := upstreams.Create(ctx, &idpv1alpha1.UpstreamOIDCProvider{
ObjectMeta: testObjectMeta(t, "upstream"),
Spec: spec,
}, metav1.CreateOptions{})
require.NoError(t, err)
// Always clean this up after this point.
t.Cleanup(func() {
err := upstreams.Delete(context.Background(), created.Name, metav1.DeleteOptions{})
require.NoError(t, err)
})
t.Logf("created test UpstreamOIDCProvider %s", created.Name)
// Wait for the UpstreamOIDCProvider to enter the expected phase (or time out).
var result *idpv1alpha1.UpstreamOIDCProvider
require.Eventuallyf(t, func() bool {
var err error
result, err = upstreams.Get(ctx, created.Name, metav1.GetOptions{})
require.NoError(t, err)
return result.Status.Phase == expectedPhase
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase)
return result
}
func testObjectMeta(t *testing.T, baseName string) metav1.ObjectMeta {
return metav1.ObjectMeta{
GenerateName: fmt.Sprintf("test-%s-", baseName),
Labels: map[string]string{"pinniped.dev/test": ""},
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
} }
return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])), nil
} }

49
test/library/dumplogs.go Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package library
import (
"bufio"
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// DumpLogs is meant to be called in a `defer` to dump the logs of components in the cluster on a test failure.
func DumpLogs(t *testing.T, namespace string) {
// Only trigger on failed tests.
if !t.Failed() {
return
}
kubeClient := NewClientset(t)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
logTailLines := int64(40)
pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{})
require.NoError(t, err)
for _, pod := range pods.Items {
for _, container := range pod.Status.ContainerStatuses {
t.Logf("pod %s/%s container %s restarted %d times:", pod.Namespace, pod.Name, container.Name, container.RestartCount)
req := kubeClient.CoreV1().Pods(namespace).GetLogs(pod.Name, &corev1.PodLogOptions{
Container: container.Name,
TailLines: &logTailLines,
})
logReader, err := req.Stream(ctx)
require.NoError(t, err)
scanner := bufio.NewScanner(logReader)
for scanner.Scan() {
t.Logf("%s/%s/%s > %s", pod.Namespace, pod.Name, container.Name, scanner.Text())
}
require.NoError(t, scanner.Err())
}
}
}

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"regexp" "regexp"
"strings"
"testing" "testing"
) )
@ -26,18 +27,22 @@ func (l *testlogReader) Read(p []byte) (n int, err error) {
l.t.Helper() l.t.Helper()
n, err = l.r.Read(p) n, err = l.r.Read(p)
if err != nil { if err != nil {
l.t.Logf("%s > %q: %v", l.name, maskTokens(p[0:n]), err) l.t.Logf("%s > %q: %v", l.name, MaskTokens(string(p[0:n])), err)
} else { } else {
l.t.Logf("%s > %q", l.name, maskTokens(p[0:n])) l.t.Logf("%s > %q", l.name, MaskTokens(string(p[0:n])))
} }
return return
} }
//nolint: gochecknoglobals // MaskTokens makes a best-effort attempt to mask out things that look like secret tokens in test output.
// The goal is more to have readable test output than for any security reason.
func MaskTokens(in string) string {
var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`) var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`)
return tokenLike.ReplaceAllStringFunc(in, func(t string) string {
func maskTokens(in []byte) string { // This is a silly heuristic, but things with multiple dots are more likely hostnames that we don't want masked.
return tokenLike.ReplaceAllStringFunc(string(in), func(t string) string { if strings.Count(t, ".") >= 4 {
return t
}
return fmt.Sprintf("[...%d bytes...]", len(t)) return fmt.Sprintf("[...%d bytes...]", len(t))
}) })
} }

16
test/library/iplookup.go Normal file
View File

@ -0,0 +1,16 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// +build !go1.14
package library
import (
"context"
"net"
)
// LookupIP looks up the IP address of the provided hostname, preferring IPv4.
func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) {
return net.DefaultResolver.LookupIP(ctx, "ip4", hostname)
}

View File

@ -0,0 +1,28 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// +build go1.14
package library
import (
"context"
"net"
)
// LookupIP looks up the IP address of the provided hostname, preferring IPv4.
func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) {
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
if err != nil {
return nil, err
}
// Filter out to only IPv4 addresses
var results []net.IP
for _, ip := range ips {
if ip.IP.To4() != nil {
results = append(results, ip.IP)
}
}
return results, nil
}