Merge remote-tracking branch 'upstream/callback-endpoint' into token-endpoint
This commit is contained in:
commit
2f1a67ef0d
@ -196,7 +196,12 @@ func run(serverInstallationNamespace string, cfg *supervisor.Config) error {
|
||||
dynamicUpstreamIDPProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||
|
||||
// OIDC endpoints will be served by the oidProvidersManager, and any non-OIDC paths will fallback to the healthMux.
|
||||
oidProvidersManager := manager.NewManager(healthMux, dynamicJWKSProvider, dynamicUpstreamIDPProvider)
|
||||
oidProvidersManager := manager.NewManager(
|
||||
healthMux,
|
||||
dynamicJWKSProvider,
|
||||
dynamicUpstreamIDPProvider,
|
||||
kubeClient.CoreV1().Secrets(serverInstallationNamespace),
|
||||
)
|
||||
|
||||
startControllers(
|
||||
ctx,
|
||||
|
@ -20,6 +20,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/filesession"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
//nolint: gochecknoinits
|
||||
@ -27,7 +28,7 @@ func init() {
|
||||
loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login))
|
||||
}
|
||||
|
||||
func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error)) *cobra.Command {
|
||||
func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error)) *cobra.Command {
|
||||
var (
|
||||
cmd = cobra.Command{
|
||||
Args: cobra.NoArgs,
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/internal/here"
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
func TestLoginOIDCCommand(t *testing.T) {
|
||||
@ -92,12 +93,12 @@ func TestLoginOIDCCommand(t *testing.T) {
|
||||
gotClientID string
|
||||
gotOptions []oidcclient.Option
|
||||
)
|
||||
cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error) {
|
||||
cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidctypes.Token, error) {
|
||||
gotIssuer = issuer
|
||||
gotClientID = clientID
|
||||
gotOptions = opts
|
||||
return &oidcclient.Token{
|
||||
IDToken: &oidcclient.IDToken{
|
||||
return &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(time1),
|
||||
},
|
||||
|
1
go.mod
1
go.mod
@ -18,6 +18,7 @@ require (
|
||||
github.com/gorilla/securecookie v1.1.1
|
||||
github.com/ory/fosite v0.35.1
|
||||
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/sclevine/agouti v3.0.0+incompatible
|
||||
github.com/sclevine/spec v1.4.0
|
||||
github.com/spf13/cobra v1.0.0
|
||||
|
@ -103,6 +103,7 @@ k8s_yaml(local([
|
||||
'--data-value-yaml', 'service_http_nodeport_nodeport=31234',
|
||||
'--data-value-yaml', 'service_https_nodeport_port=443',
|
||||
'--data-value-yaml', 'service_https_nodeport_nodeport=31243',
|
||||
'--data-value-yaml', 'service_https_clusterip_port=443',
|
||||
'--data-value-yaml', 'custom_labels={mySupervisorCustomLabelName: mySupervisorCustomLabelValue}',
|
||||
]))
|
||||
# Tell tilt to watch all of those files for changes.
|
||||
|
@ -230,6 +230,7 @@ if ! tilt_mode; then
|
||||
--data-value-yaml 'service_http_nodeport_nodeport=31234' \
|
||||
--data-value-yaml 'service_https_nodeport_port=443' \
|
||||
--data-value-yaml 'service_https_nodeport_nodeport=31243' \
|
||||
--data-value-yaml 'service_https_clusterip_port=443' \
|
||||
>"$manifest"
|
||||
|
||||
kapp deploy --yes --app "$supervisor_app_name" --diff-changes --file "$manifest"
|
||||
@ -302,7 +303,7 @@ export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER=https://dex.dex.svc.cluster
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_ISSUER_CA_BUNDLE="${test_ca_bundle_pem}"
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_ID=pinniped-supervisor
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CLIENT_SECRET=pinniped-supervisor-secret
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://127.0.0.1:12345/some/path/callback
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_CALLBACK_URL=https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME=pinny@example.com
|
||||
export PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD=password
|
||||
|
||||
|
@ -136,6 +136,13 @@ func (c *CA) Bundle() []byte {
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes})
|
||||
}
|
||||
|
||||
// Pool returns the current CA signing bundle as a *x509.CertPool.
|
||||
func (c *CA) Pool() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
pool.AppendCertsFromPEM(c.Bundle())
|
||||
return pool
|
||||
}
|
||||
|
||||
// Issue a new server certificate for the given identity and duration.
|
||||
func (c *CA) Issue(subject pkix.Name, dnsNames []string, ips []net.IP, ttl time.Duration) (*tls.Certificate, error) {
|
||||
// Choose a random 128 bit serial number.
|
||||
|
@ -182,6 +182,16 @@ func TestBundle(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPool(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
ca, err := New(pkix.Name{CommonName: "test"}, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := ca.Pool()
|
||||
require.Len(t, got.Subjects(), 1)
|
||||
})
|
||||
}
|
||||
|
||||
type errSigner struct {
|
||||
pubkey crypto.PublicKey
|
||||
err error
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/go-logr/logr"
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/api/equality"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/labels"
|
||||
@ -30,6 +31,7 @@ import (
|
||||
pinnipedcontroller "go.pinniped.dev/internal/controller"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/upstreamoidc"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -62,21 +64,27 @@ const (
|
||||
|
||||
// IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
|
||||
type IDPCache interface {
|
||||
SetIDPList([]provider.UpstreamOIDCIdentityProvider)
|
||||
SetIDPList([]provider.UpstreamOIDCIdentityProviderI)
|
||||
}
|
||||
|
||||
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
|
||||
type lruValidatorCache struct{ cache *cache.Expiring }
|
||||
|
||||
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider {
|
||||
if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
|
||||
return result.(*oidc.Provider)
|
||||
}
|
||||
return nil
|
||||
type lruValidatorCacheEntry struct {
|
||||
provider *oidc.Provider
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider) {
|
||||
c.cache.Set(c.cacheKey(spec), provider, validatorCacheTTL)
|
||||
func (c *lruValidatorCache) getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client) {
|
||||
if result, ok := c.cache.Get(c.cacheKey(spec)); ok {
|
||||
entry := result.(*lruValidatorCacheEntry)
|
||||
return entry.provider, entry.client
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *lruValidatorCache) putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider, client *http.Client) {
|
||||
c.cache.Set(c.cacheKey(spec), &lruValidatorCacheEntry{provider: provider, client: client}, validatorCacheTTL)
|
||||
}
|
||||
|
||||
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.UpstreamOIDCProviderSpec) interface{} {
|
||||
@ -95,8 +103,8 @@ type controller struct {
|
||||
providers idpinformers.UpstreamOIDCProviderInformer
|
||||
secrets corev1informers.SecretInformer
|
||||
validatorCache interface {
|
||||
getProvider(spec *v1alpha1.UpstreamOIDCProviderSpec) *oidc.Provider
|
||||
putProvider(spec *v1alpha1.UpstreamOIDCProviderSpec, provider *oidc.Provider)
|
||||
getProvider(*v1alpha1.UpstreamOIDCProviderSpec) (*oidc.Provider, *http.Client)
|
||||
putProvider(*v1alpha1.UpstreamOIDCProviderSpec, *oidc.Provider, *http.Client)
|
||||
}
|
||||
}
|
||||
|
||||
@ -132,13 +140,13 @@ func (c *controller) Sync(ctx controllerlib.Context) error {
|
||||
}
|
||||
|
||||
requeue := false
|
||||
validatedUpstreams := make([]provider.UpstreamOIDCIdentityProvider, 0, len(actualUpstreams))
|
||||
validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams))
|
||||
for _, upstream := range actualUpstreams {
|
||||
valid := c.validateUpstream(ctx, upstream)
|
||||
if valid == nil {
|
||||
requeue = true
|
||||
} else {
|
||||
validatedUpstreams = append(validatedUpstreams, *valid)
|
||||
validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid))
|
||||
}
|
||||
}
|
||||
c.cache.SetIDPList(validatedUpstreams)
|
||||
@ -150,10 +158,14 @@ func (c *controller) Sync(ctx controllerlib.Context) error {
|
||||
|
||||
// validateUpstream validates the provided v1alpha1.UpstreamOIDCProvider and returns the validated configuration as a
|
||||
// provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.UpstreamOIDCProvider.
|
||||
func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *provider.UpstreamOIDCIdentityProvider {
|
||||
result := provider.UpstreamOIDCIdentityProvider{
|
||||
func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.UpstreamOIDCProvider) *upstreamoidc.ProviderConfig {
|
||||
result := upstreamoidc.ProviderConfig{
|
||||
Name: upstream.Name,
|
||||
Config: &oauth2.Config{
|
||||
Scopes: computeScopes(upstream.Spec.AuthorizationConfig.AdditionalScopes),
|
||||
},
|
||||
UsernameClaim: upstream.Spec.Claims.Username,
|
||||
GroupsClaim: upstream.Spec.Claims.Groups,
|
||||
}
|
||||
conditions := []*v1alpha1.Condition{
|
||||
c.validateSecret(upstream, &result),
|
||||
@ -180,7 +192,7 @@ func (c *controller) validateUpstream(ctx controllerlib.Context, upstream *v1alp
|
||||
}
|
||||
|
||||
// validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition.
|
||||
func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition {
|
||||
func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
|
||||
secretName := upstream.Spec.Client.SecretName
|
||||
|
||||
// Fetch the Secret from informer cache.
|
||||
@ -217,7 +229,8 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
|
||||
}
|
||||
|
||||
// If everything is valid, update the result and set the condition to true.
|
||||
result.ClientID = string(clientID)
|
||||
result.Config.ClientID = string(clientID)
|
||||
result.Config.ClientSecret = string(clientSecret)
|
||||
return &v1alpha1.Condition{
|
||||
Type: typeClientCredsValid,
|
||||
Status: v1alpha1.ConditionTrue,
|
||||
@ -227,9 +240,9 @@ func (c *controller) validateSecret(upstream *v1alpha1.UpstreamOIDCProvider, res
|
||||
}
|
||||
|
||||
// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
|
||||
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *provider.UpstreamOIDCIdentityProvider) *v1alpha1.Condition {
|
||||
// Get the provider (from cache if possible).
|
||||
discoveredProvider := c.validatorCache.getProvider(&upstream.Spec)
|
||||
func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.UpstreamOIDCProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
|
||||
// Get the provider and HTTP Client from cache if possible.
|
||||
discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec)
|
||||
|
||||
// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
|
||||
if discoveredProvider == nil {
|
||||
@ -242,7 +255,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
|
||||
httpClient = &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConfig}}
|
||||
|
||||
discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
|
||||
if err != nil {
|
||||
@ -255,7 +268,7 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
||||
}
|
||||
|
||||
// Update the cache with the newly discovered value.
|
||||
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider)
|
||||
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient)
|
||||
}
|
||||
|
||||
// Parse out and validate the discovered authorize endpoint.
|
||||
@ -278,7 +291,9 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
|
||||
}
|
||||
|
||||
// If everything is valid, update the result and set the condition to true.
|
||||
result.AuthorizationURL = *authURL
|
||||
result.Config.Endpoint = discoveredProvider.Endpoint()
|
||||
result.Provider = discoveredProvider
|
||||
result.Client = httpClient
|
||||
return &v1alpha1.Condition{
|
||||
Type: typeOIDCDiscoverySucceeded,
|
||||
Status: v1alpha1.ConditionTrue,
|
||||
|
@ -24,9 +24,11 @@ import (
|
||||
pinnipedfake "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned/fake"
|
||||
pinnipedinformers "go.pinniped.dev/generated/1.19/client/supervisor/informers/externalversions"
|
||||
"go.pinniped.dev/internal/controllerlib"
|
||||
"go.pinniped.dev/internal/oidc/oidctestutil"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/internal/testutil/testlogger"
|
||||
"go.pinniped.dev/internal/upstreamoidc"
|
||||
)
|
||||
|
||||
func TestController(t *testing.T) {
|
||||
@ -49,6 +51,8 @@ func TestController(t *testing.T) {
|
||||
testClientID = "test-oidc-client-id"
|
||||
testClientSecret = "test-oidc-client-secret"
|
||||
testValidSecretData = map[string][]byte{"clientID": []byte(testClientID), "clientSecret": []byte(testClientSecret)}
|
||||
testGroupsClaim = "test-groups-claim"
|
||||
testUsernameClaim = "test-username-claim"
|
||||
)
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -56,7 +60,7 @@ func TestController(t *testing.T) {
|
||||
inputSecrets []runtime.Object
|
||||
wantErr string
|
||||
wantLogs []string
|
||||
wantResultingCache []provider.UpstreamOIDCIdentityProvider
|
||||
wantResultingCache []provider.UpstreamOIDCIdentityProviderI
|
||||
wantResultingUpstreams []v1alpha1.UpstreamOIDCProvider
|
||||
}{
|
||||
{
|
||||
@ -80,7 +84,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="secret \"test-client-secret\" not found" "name"="test-name" "namespace"="test-namespace" "reason"="SecretNotFound" "type"="ClientCredentialsValid"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -126,7 +130,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" has wrong type \"some-other-type\" (should be \"secrets.pinniped.dev/oidc-client\")" "name"="test-name" "namespace"="test-namespace" "reason"="SecretWrongType" "type"="ClientCredentialsValid"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -171,7 +175,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="referenced Secret \"test-client-secret\" is missing required keys [\"clientID\" \"clientSecret\"]" "name"="test-name" "namespace"="test-namespace" "reason"="SecretMissingKeys" "type"="ClientCredentialsValid"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -219,7 +223,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: illegal base64 data at input byte 7" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -267,7 +271,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="spec.certificateAuthorityData is invalid: no certificates found" "reason"="InvalidTLSConfig" "status"="False" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="spec.certificateAuthorityData is invalid: no certificates found" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidTLSConfig" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -312,7 +316,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to perform OIDC discovery against \"invalid-url\"" "reason"="Unreachable" "status"="False" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to perform OIDC discovery against \"invalid-url\"" "name"="test-name" "namespace"="test-namespace" "reason"="Unreachable" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -358,7 +362,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="failed to parse authorization endpoint URL: parse \"%\": invalid URL escape \"%\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -404,7 +408,7 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "reason"="InvalidResponse" "status"="False" "type"="OIDCDiscoverySucceeded"`,
|
||||
`upstream-observer "error"="UpstreamOIDCProvider has a failing condition" "msg"="found failing condition" "message"="authorization endpoint URL scheme must be \"https\", not \"http\"" "name"="test-name" "namespace"="test-namespace" "reason"="InvalidResponse" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -437,6 +441,7 @@ func TestController(t *testing.T) {
|
||||
TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64},
|
||||
Client: v1alpha1.OIDCClient{SecretName: testSecretName},
|
||||
AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: append(testAdditionalScopes, "xyz", "openid")},
|
||||
Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim},
|
||||
},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
Phase: "Error",
|
||||
@ -455,12 +460,16 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`,
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{
|
||||
&oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: testName,
|
||||
ClientID: testClientID,
|
||||
AuthorizationURL: *testIssuerAuthorizeURL,
|
||||
Scopes: append(testExpectedScopes, "xyz"),
|
||||
}},
|
||||
UsernameClaim: testUsernameClaim,
|
||||
GroupsClaim: testGroupsClaim,
|
||||
},
|
||||
},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -481,6 +490,7 @@ func TestController(t *testing.T) {
|
||||
TLS: &v1alpha1.TLSSpec{CertificateAuthorityData: testIssuerCABase64},
|
||||
Client: v1alpha1.OIDCClient{SecretName: testSecretName},
|
||||
AuthorizationConfig: v1alpha1.OIDCAuthorizationConfig{AdditionalScopes: testAdditionalScopes},
|
||||
Claims: v1alpha1.OIDCClaims{Groups: testGroupsClaim, Username: testUsernameClaim},
|
||||
},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
Phase: "Ready",
|
||||
@ -499,12 +509,16 @@ func TestController(t *testing.T) {
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="loaded client credentials" "reason"="Success" "status"="True" "type"="ClientCredentialsValid"`,
|
||||
`upstream-observer "level"=0 "msg"="updated condition" "name"="test-name" "namespace"="test-namespace" "message"="discovered issuer configuration" "reason"="Success" "status"="True" "type"="OIDCDiscoverySucceeded"`,
|
||||
},
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProvider{{
|
||||
wantResultingCache: []provider.UpstreamOIDCIdentityProviderI{
|
||||
&oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: testName,
|
||||
ClientID: testClientID,
|
||||
AuthorizationURL: *testIssuerAuthorizeURL,
|
||||
Scopes: testExpectedScopes,
|
||||
}},
|
||||
UsernameClaim: testUsernameClaim,
|
||||
GroupsClaim: testGroupsClaim,
|
||||
},
|
||||
},
|
||||
wantResultingUpstreams: []v1alpha1.UpstreamOIDCProvider{{
|
||||
ObjectMeta: metav1.ObjectMeta{Namespace: testNamespace, Name: testName, Generation: 1234},
|
||||
Status: v1alpha1.UpstreamOIDCProviderStatus{
|
||||
@ -527,7 +541,9 @@ func TestController(t *testing.T) {
|
||||
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
|
||||
testLog := testlogger.New(t)
|
||||
cache := provider.NewDynamicUpstreamIDPProvider()
|
||||
cache.SetIDPList([]provider.UpstreamOIDCIdentityProvider{{Name: "initial-entry"}})
|
||||
cache.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{
|
||||
&upstreamoidc.ProviderConfig{Name: "initial-entry"},
|
||||
})
|
||||
|
||||
controller := New(
|
||||
cache,
|
||||
@ -551,7 +567,18 @@ func TestController(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n"))
|
||||
require.ElementsMatch(t, tt.wantResultingCache, cache.GetIDPList())
|
||||
|
||||
actualIDPList := cache.GetIDPList()
|
||||
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
|
||||
for i := range actualIDPList {
|
||||
actualIDP := actualIDPList[i].(*upstreamoidc.ProviderConfig)
|
||||
require.Equal(t, tt.wantResultingCache[i].GetName(), actualIDP.GetName())
|
||||
require.Equal(t, tt.wantResultingCache[i].GetClientID(), actualIDP.GetClientID())
|
||||
require.Equal(t, tt.wantResultingCache[i].GetAuthorizationURL().String(), actualIDP.GetAuthorizationURL().String())
|
||||
require.Equal(t, tt.wantResultingCache[i].GetUsernameClaim(), actualIDP.GetUsernameClaim())
|
||||
require.Equal(t, tt.wantResultingCache[i].GetGroupsClaim(), actualIDP.GetGroupsClaim())
|
||||
require.ElementsMatch(t, tt.wantResultingCache[i].GetScopes(), actualIDP.GetScopes())
|
||||
}
|
||||
|
||||
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
|
@ -30,7 +30,7 @@ const (
|
||||
|
||||
ErrSecretTypeMismatch = constable.Error("secret storage data has incorrect type")
|
||||
ErrSecretLabelMismatch = constable.Error("secret storage data has incorrect label")
|
||||
ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version") // TODO do we need this?
|
||||
ErrSecretVersionMismatch = constable.Error("secret storage data has incorrect version")
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
@ -139,7 +139,7 @@ func (s *secretsStorage) toSecret(signature, resourceVersion string, data JSON)
|
||||
Labels: map[string]string{
|
||||
secretLabelKey: s.resource, // make it easier to find this stuff via kubectl
|
||||
},
|
||||
OwnerReferences: nil, // TODO we should set this to make sure stuff gets clean up
|
||||
OwnerReferences: nil,
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
secretDataKey: buf,
|
||||
|
@ -62,17 +62,17 @@ func TestStorage(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "get non-existent",
|
||||
resource: "authorization-codes",
|
||||
resource: "authcode",
|
||||
mocks: nil,
|
||||
run: func(t *testing.T, storage Storage) error {
|
||||
_, err := storage.Get(ctx, "not-exists", nil)
|
||||
return err
|
||||
},
|
||||
wantActions: []coretesting.Action{
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-t2fx46yyvs3a"),
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-t2fx46yyvs3a"),
|
||||
},
|
||||
wantSecrets: nil,
|
||||
wantErr: `failed to get authorization-codes for signature not-exists: secrets "pinniped-storage-authorization-codes-t2fx46yyvs3a" not found`,
|
||||
wantErr: `failed to get authcode for signature not-exists: secrets "pinniped-storage-authcode-t2fx46yyvs3a" not found`,
|
||||
},
|
||||
{
|
||||
name: "delete non-existent",
|
||||
|
@ -1,334 +0,0 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authorizationcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
fuzz "github.com/google/gofuzz"
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/oauth2"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
coretesting "k8s.io/client-go/testing"
|
||||
)
|
||||
|
||||
func TestAuthorizeCodeStorage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
secretsGVR := schema.GroupVersionResource{
|
||||
Group: "",
|
||||
Version: "v1",
|
||||
Resource: "secrets",
|
||||
}
|
||||
|
||||
const namespace = "test-ns"
|
||||
|
||||
type mocker interface {
|
||||
AddReactor(verb, resource string, reaction coretesting.ReactionFunc)
|
||||
PrependReactor(verb, resource string, reaction coretesting.ReactionFunc)
|
||||
Tracker() coretesting.ObjectTracker
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mocks func(*testing.T, mocker)
|
||||
run func(*testing.T, oauth2.AuthorizeCodeStorage) error
|
||||
wantActions []coretesting.Action
|
||||
wantSecrets []corev1.Secret
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "create, get, invalidate standard flow",
|
||||
mocks: nil,
|
||||
run: func(t *testing.T, storage oauth2.AuthorizeCodeStorage) error {
|
||||
request := &fosite.AuthorizeRequest{
|
||||
ResponseTypes: fosite.Arguments{"not-code"},
|
||||
RedirectURI: &url.URL{
|
||||
Scheme: "",
|
||||
Opaque: "weee",
|
||||
User: &url.Userinfo{},
|
||||
Host: "",
|
||||
Path: "/callback",
|
||||
RawPath: "",
|
||||
ForceQuery: false,
|
||||
RawQuery: "",
|
||||
Fragment: "",
|
||||
},
|
||||
State: "stated",
|
||||
HandledResponseTypes: fosite.Arguments{"not-type"},
|
||||
Request: fosite.Request{
|
||||
ID: "abcd-1",
|
||||
RequestedAt: time.Time{},
|
||||
Client: &fosite.DefaultOpenIDConnectClient{
|
||||
DefaultClient: &fosite.DefaultClient{
|
||||
ID: "pinny",
|
||||
Secret: nil,
|
||||
RedirectURIs: nil,
|
||||
GrantTypes: nil,
|
||||
ResponseTypes: nil,
|
||||
Scopes: nil,
|
||||
Audience: nil,
|
||||
Public: true,
|
||||
},
|
||||
JSONWebKeysURI: "where",
|
||||
JSONWebKeys: nil,
|
||||
TokenEndpointAuthMethod: "something",
|
||||
RequestURIs: nil,
|
||||
RequestObjectSigningAlgorithm: "",
|
||||
TokenEndpointAuthSigningAlgorithm: "",
|
||||
},
|
||||
RequestedScope: nil,
|
||||
GrantedScope: nil,
|
||||
Form: url.Values{"key": []string{"val"}},
|
||||
Session: &openid.DefaultSession{
|
||||
Claims: nil,
|
||||
Headers: nil,
|
||||
ExpiresAt: nil,
|
||||
Username: "snorlax",
|
||||
Subject: "panda",
|
||||
},
|
||||
RequestedAudience: nil,
|
||||
GrantedAudience: nil,
|
||||
},
|
||||
}
|
||||
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
|
||||
require.NoError(t, err)
|
||||
|
||||
newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, request, newRequest)
|
||||
|
||||
return storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
|
||||
},
|
||||
wantActions: []coretesting.Action{
|
||||
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authorization-codes",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"active":true,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authorization-codes",
|
||||
}),
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"),
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4"),
|
||||
coretesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authorization-codes",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authorization-codes",
|
||||
}),
|
||||
},
|
||||
wantSecrets: []corev1.Secret{
|
||||
{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authorization-codes-pwu5zs7lekbhnln2w4",
|
||||
Namespace: namespace,
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authorization-codes",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"active":false,"request":{"responseTypes":["not-code"],"redirectUri":{"Scheme":"","Opaque":"weee","User":{},"Host":"","Path":"/callback","RawPath":"","ForceQuery":false,"RawQuery":"","Fragment":"","RawFragment":""},"state":"stated","handledResponseTypes":["not-type"],"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authorization-codes",
|
||||
},
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client := fake.NewSimpleClientset()
|
||||
if tt.mocks != nil {
|
||||
tt.mocks(t, client)
|
||||
}
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
err := tt.run(t, storage)
|
||||
|
||||
require.Equal(t, tt.wantErr, errString(err))
|
||||
require.Equal(t, tt.wantActions, client.Actions())
|
||||
|
||||
actualSecrets, err := secrets.List(ctx, metav1.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantSecrets, actualSecrets.Items)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func errString(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session.
|
||||
// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations.
|
||||
func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) {
|
||||
validSession := NewValidEmptyAuthorizeCodeSession()
|
||||
|
||||
// sanity check our valid session
|
||||
extractedRequest, err := validateAndExtractAuthorizeRequest(validSession.Request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, validSession.Request, extractedRequest)
|
||||
|
||||
// checked above
|
||||
defaultClient := validSession.Request.Request.Client.(*fosite.DefaultOpenIDConnectClient)
|
||||
defaultSession := validSession.Request.Request.Session.(*openid.DefaultSession)
|
||||
|
||||
// makes it easier to use a raw string
|
||||
replacer := strings.NewReplacer("`", "a")
|
||||
randString := func(c fuzz.Continue) string {
|
||||
for {
|
||||
s := c.RandString()
|
||||
if len(s) == 0 {
|
||||
continue // skip empty string
|
||||
}
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
}
|
||||
|
||||
// deterministic fuzzing of fosite.AuthorizeRequest
|
||||
f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs(
|
||||
// these functions guarantee that these are the only interface types we need to fill out
|
||||
// if fosite.AuthorizeRequest changes to add more, the fuzzer will panic
|
||||
func(fc *fosite.Client, c fuzz.Continue) {
|
||||
c.Fuzz(defaultClient)
|
||||
*fc = defaultClient
|
||||
},
|
||||
func(fs *fosite.Session, c fuzz.Continue) {
|
||||
c.Fuzz(defaultSession)
|
||||
*fs = defaultSession
|
||||
},
|
||||
|
||||
// these types contain an interface{} that we need to handle
|
||||
// this is safe because we explicitly provide the openid.DefaultSession concrete type
|
||||
func(value *map[string]interface{}, c fuzz.Continue) {
|
||||
// cover all the JSON data types just in case
|
||||
*value = map[string]interface{}{
|
||||
randString(c): float64(c.Intn(1 << 32)),
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): []interface{}{float64(c.Intn(1 << 32))},
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): nil,
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): c.RandBool(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
// JWK contains an interface{} Key that we need to handle
|
||||
// this is safe because JWK explicitly implements JSON marshalling and unmarshalling
|
||||
func(jwk *jose.JSONWebKey, c fuzz.Continue) {
|
||||
key, _, err := ed25519.GenerateKey(c)
|
||||
require.NoError(t, err)
|
||||
jwk.Key = key
|
||||
|
||||
// set these fields to make the .Equal comparison work
|
||||
jwk.Certificates = []*x509.Certificate{}
|
||||
jwk.CertificatesURL = &url.URL{}
|
||||
jwk.CertificateThumbprintSHA1 = []byte{}
|
||||
jwk.CertificateThumbprintSHA256 = []byte{}
|
||||
},
|
||||
|
||||
// set this to make the .Equal comparison work
|
||||
// this is safe because Time explicitly implements JSON marshalling and unmarshalling
|
||||
func(tp *time.Time, c fuzz.Continue) {
|
||||
*tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC()
|
||||
},
|
||||
|
||||
// make random strings that do not contain any ` characters
|
||||
func(s *string, c fuzz.Continue) {
|
||||
*s = randString(c)
|
||||
},
|
||||
// handle string type alias
|
||||
func(s *fosite.TokenType, c fuzz.Continue) {
|
||||
*s = fosite.TokenType(randString(c))
|
||||
},
|
||||
// handle string type alias
|
||||
func(s *fosite.Arguments, c fuzz.Continue) {
|
||||
n := c.Intn(3) + 1 // 1 to 3 items
|
||||
arguments := make(fosite.Arguments, n)
|
||||
for i := range arguments {
|
||||
arguments[i] = randString(c)
|
||||
}
|
||||
*s = arguments
|
||||
},
|
||||
)
|
||||
|
||||
f.Fuzz(validSession)
|
||||
|
||||
const name = "fuzz" // value is irrelevant
|
||||
ctx := context.Background()
|
||||
secrets := fake.NewSimpleClientset().CoreV1().Secrets(name)
|
||||
storage := New(secrets)
|
||||
|
||||
// issue a create using the fuzzed request to confirm that marshalling works
|
||||
err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request)
|
||||
require.NoError(t, err)
|
||||
|
||||
// retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works
|
||||
newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// the fuzzed request and the copy from storage should be exactly the same
|
||||
require.Equal(t, validSession.Request, newRequest)
|
||||
|
||||
secretList, err := secrets.List(ctx, metav1.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretList.Items, 1)
|
||||
authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"])
|
||||
|
||||
// set these to match CreateAuthorizeCodeSession so that .JSONEq works
|
||||
validSession.Active = true
|
||||
validSession.Version = "1"
|
||||
|
||||
validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t")
|
||||
require.NoError(t, err)
|
||||
authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes)
|
||||
|
||||
// the fuzzed session and storage session should have identical JSON
|
||||
require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage)
|
||||
|
||||
// while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed,
|
||||
// if it adds a new field that can be fuzzed, this check will fail
|
||||
// thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field)
|
||||
require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing)
|
||||
}
|
@ -16,11 +16,11 @@ import (
|
||||
|
||||
"go.pinniped.dev/internal/constable"
|
||||
"go.pinniped.dev/internal/crud"
|
||||
"go.pinniped.dev/internal/fositestorage"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrInvalidAuthorizeRequestType = constable.Error("authorization request must be of type fosite.AuthorizeRequest")
|
||||
ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must not be nil")
|
||||
ErrInvalidAuthorizeRequestData = constable.Error("authorization request data must be present")
|
||||
ErrInvalidAuthorizeRequestVersion = constable.Error("authorization request data has wrong version")
|
||||
|
||||
authorizeCodeStorageVersion = "1"
|
||||
@ -34,25 +34,24 @@ type authorizeCodeStorage struct {
|
||||
|
||||
type AuthorizeCodeSession struct {
|
||||
Active bool `json:"active"`
|
||||
Request *fosite.AuthorizeRequest `json:"request"`
|
||||
Request *fosite.Request `json:"request"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func New(secrets corev1client.SecretInterface) oauth2.AuthorizeCodeStorage {
|
||||
return &authorizeCodeStorage{storage: crud.New("authorization-codes", secrets)}
|
||||
return &authorizeCodeStorage{storage: crud.New("authcode", secrets)}
|
||||
}
|
||||
|
||||
func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) error {
|
||||
// this conversion assumes that we do not wrap the default type in any way
|
||||
// This conversion assumes that we do not wrap the default type in any way
|
||||
// i.e. we use the default fosite.OAuth2Provider.NewAuthorizeRequest implementation
|
||||
// note that because this type is serialized and stored in Kube, we cannot easily change the implementation later
|
||||
// TODO hydra uses the fosite.Request struct and ignores the extra fields in fosite.AuthorizeRequest
|
||||
request, err := validateAndExtractAuthorizeRequest(requester)
|
||||
request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO hydra stores specific fields from the requester
|
||||
// Note, in case it is helpful, that Hydra stores specific fields from the requester:
|
||||
// request ID
|
||||
// requestedAt
|
||||
// OAuth client ID
|
||||
@ -70,12 +69,11 @@ func (a *authorizeCodeStorage) CreateAuthorizeCodeSession(ctx context.Context, s
|
||||
}
|
||||
|
||||
func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) {
|
||||
// TODO hydra uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes
|
||||
|
||||
// TODO hydra gets the client from its DB as a concrete type via client ID,
|
||||
// the hydra memory client just validates that the client ID exists
|
||||
|
||||
// TODO hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length
|
||||
// Note, in case it is helpful, that Hydra:
|
||||
// - uses the incoming fosite.Session to provide the type needed to json.Unmarshal their session bytes
|
||||
// - gets the client from its DB as a concrete type via client ID, the hydra memory client just validates that the
|
||||
// client ID exists
|
||||
// - hydra uses the sha512.Sum384 hash of signature when using JWT as access token to reduce length
|
||||
|
||||
session, _, err := a.getSession(ctx, signature)
|
||||
|
||||
@ -88,8 +86,6 @@ func (a *authorizeCodeStorage) GetAuthorizeCodeSession(ctx context.Context, sign
|
||||
}
|
||||
|
||||
func (a *authorizeCodeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) error {
|
||||
// TODO write garbage collector for these codes
|
||||
|
||||
session, rv, err := a.getSession(ctx, signature)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -123,7 +119,7 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string)
|
||||
ErrInvalidAuthorizeRequestVersion, signature, version, authorizeCodeStorageVersion)
|
||||
}
|
||||
|
||||
if session.Request == nil {
|
||||
if session.Request.ID == "" {
|
||||
return nil, "", fmt.Errorf("malformed authorization code session for %s: %w", signature, ErrInvalidAuthorizeRequestData)
|
||||
}
|
||||
|
||||
@ -137,31 +133,13 @@ func (a *authorizeCodeStorage) getSession(ctx context.Context, signature string)
|
||||
|
||||
func NewValidEmptyAuthorizeCodeSession() *AuthorizeCodeSession {
|
||||
return &AuthorizeCodeSession{
|
||||
Request: &fosite.AuthorizeRequest{
|
||||
Request: fosite.Request{
|
||||
Request: &fosite.Request{
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
Session: &openid.DefaultSession{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.AuthorizeRequest, error) {
|
||||
request, ok1 := requester.(*fosite.AuthorizeRequest)
|
||||
if !ok1 {
|
||||
return nil, ErrInvalidAuthorizeRequestType
|
||||
}
|
||||
_, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient)
|
||||
_, ok3 := request.Session.(*openid.DefaultSession)
|
||||
|
||||
valid := ok2 && ok3
|
||||
if !valid {
|
||||
return nil, ErrInvalidAuthorizeRequestType
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
var _ interface {
|
||||
Is(error) bool
|
||||
Unwrap() error
|
||||
@ -189,59 +167,37 @@ func (e *errSerializationFailureWithCause) Error() string {
|
||||
const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
|
||||
"active": true,
|
||||
"request": {
|
||||
"responseTypes": [
|
||||
"¥Îʒ襧.ɕ7崛瀇莒AȒ[ɠ牐7#$ɭ",
|
||||
".5ȿELj9ûF済(D疻翋膗",
|
||||
"螤Yɫüeɯ紤邥翔勋\\RBʒ;-"
|
||||
],
|
||||
"redirectUri": {
|
||||
"Scheme": "ħesƻU赒M喦_ģ",
|
||||
"Opaque": "Ġ/_章Ņ缘T蝟NJ儱礹燃ɢ",
|
||||
"User": {},
|
||||
"Host": "ȳ4螘Wo",
|
||||
"Path": "}i{",
|
||||
"RawPath": "5Dža丝eF0eė鱊hǒx蔼Q",
|
||||
"ForceQuery": true,
|
||||
"RawQuery": "熤1bbWV",
|
||||
"Fragment": "ȋc剠鏯ɽÿ¸",
|
||||
"RawFragment": "qƤ"
|
||||
},
|
||||
"state": "@n,x竘Şǥ嗾稀'ã击漰怼禝穞梠Ǫs",
|
||||
"handledResponseTypes": [
|
||||
"m\"e尚鬞ƻɼ抹d誉y鿜Ķ"
|
||||
],
|
||||
"id": "ō澩ć|3U2Ǜl霨ǦǵpƉ",
|
||||
"requestedAt": "1989-11-05T22:02:31.105295894Z",
|
||||
"id": "嫎l蟲aƖ啘艿",
|
||||
"requestedAt": "2082-11-10T18:36:11.627253638Z",
|
||||
"client": {
|
||||
"id": "[:c顎疻紵D",
|
||||
"client_secret": "mQ==",
|
||||
"id": "!ſɄĈp[述齛ʘUȻ.5ȿE",
|
||||
"client_secret": "UQ==",
|
||||
"redirect_uris": [
|
||||
"恣S@T嵇LJV,Æ櫔袆鋹奘菲",
|
||||
"ãƻʚ肈ą8O+a駣Ʉɼk瘸'鴵y"
|
||||
"ǣ珑 ʑ飶畛Ȳ螤Yɫüeɯ紤邥翔勋\\",
|
||||
"Bʒ;",
|
||||
"鿃攴Ųęʍ鎾ʦ©cÏN,Ġ/_"
|
||||
],
|
||||
"grant_types": [
|
||||
".湆ê\"唐",
|
||||
"曎餄FxD溪躲珫ÈşɜȨû臓嬣\"ǃŤz"
|
||||
"憉sHĒ尥窘挼Ŀʼn"
|
||||
],
|
||||
"response_types": [
|
||||
"Ņʘʟ車sʊ儓JǐŪɺǣy|耑ʄ"
|
||||
"4",
|
||||
"ʄÔ@}i{絧遗Ū^ȝĸ谋Vʋ鱴閇T"
|
||||
],
|
||||
"scopes": [
|
||||
"Ą",
|
||||
"萙Į(潶饏熞ĝƌĆ1",
|
||||
"əȤ4Į筦p煖鵄$睱奐耡q"
|
||||
"R鴝順諲ŮŚ节ȭŀȋc剠鏯ɽÿ¸"
|
||||
],
|
||||
"audience": [
|
||||
"Ʃǣ鿫/Ò敫ƤV"
|
||||
"Ƥ"
|
||||
],
|
||||
"public": true,
|
||||
"jwks_uri": "ȩđ[嬧鱒Ȁ彆媚杨嶒ĤG",
|
||||
"jwks_uri": "BA瘪囷ɫCʄɢ雐譄uée'",
|
||||
"jwks": {
|
||||
"keys": [
|
||||
{
|
||||
"kty": "OKP",
|
||||
"crv": "Ed25519",
|
||||
"x": "JmA-6KpjzqKu0lq9OiB6ORL4s2UzBFPsE1hm6vESeXM",
|
||||
"x": "nK9xgX_iN7u3u_i8YOO7ZRT_WK028Vd_nhtsUu7Eo6E",
|
||||
"x5u": {
|
||||
"Scheme": "",
|
||||
"Opaque": "",
|
||||
@ -258,24 +214,7 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
|
||||
{
|
||||
"kty": "OKP",
|
||||
"crv": "Ed25519",
|
||||
"x": "LbRC1_3HEe5o7Japk9jFp3_7Ou7Gi2gpqrVrIi0eLDQ",
|
||||
"x5u": {
|
||||
"Scheme": "",
|
||||
"Opaque": "",
|
||||
"User": null,
|
||||
"Host": "",
|
||||
"Path": "",
|
||||
"RawPath": "",
|
||||
"ForceQuery": false,
|
||||
"RawQuery": "",
|
||||
"Fragment": "",
|
||||
"RawFragment": ""
|
||||
}
|
||||
},
|
||||
{
|
||||
"kty": "OKP",
|
||||
"crv": "Ed25519",
|
||||
"x": "Ovk4DF8Yn3mkULuTqnlGJxFnKGu9EL6Xcf2Nql9lK3c",
|
||||
"x": "UbbswQgzWhfGCRlwQmMp6fw_HoIoqkIaKT-2XN2fuYU",
|
||||
"x5u": {
|
||||
"Scheme": "",
|
||||
"Opaque": "",
|
||||
@ -291,91 +230,95 @@ const ExpectedAuthorizeCodeSessionJSONFromFuzzing = `{
|
||||
}
|
||||
]
|
||||
},
|
||||
"token_endpoint_auth_method": "\u0026(K鵢Kj ŏ9Q韉Ķ%嶑輫ǘ(",
|
||||
"token_endpoint_auth_method": "ŚǗƳȕ暭Q0ņP羾,塐",
|
||||
"request_uris": [
|
||||
":",
|
||||
"6ě#嫀^xz Ū胧r"
|
||||
"lj翻LH^俤µDzɹ@©|\u003eɃ",
|
||||
"[:c顎疻紵D"
|
||||
],
|
||||
"request_object_signing_alg": "^¡!犃ĹĐJí¿ō擫ų懫砰¿",
|
||||
"token_endpoint_auth_signing_alg": "ƈŮå"
|
||||
"request_object_signing_alg": "m1Ì恣S@T嵇LJV,Æ櫔袆鋹奘",
|
||||
"token_endpoint_auth_signing_alg": "Fãƻʚ肈ą8O+a駣"
|
||||
},
|
||||
"scopes": [
|
||||
"阃.Ù頀ʌGa皶竇瞍涘¹",
|
||||
"ȽŮ切衖庀ŰŒ矠",
|
||||
"楓)馻řĝǕ菸Tĕ1伞柲\u003c\"ʗȆ\\雤"
|
||||
"ɼk瘸'鴵yſǮŁ±\u003eFA曎餄FxD溪",
|
||||
"綻N镪p赌h%桙dĽ"
|
||||
],
|
||||
"grantedScopes": [
|
||||
"ơ鮫R嫁ɍUƞ9+u!Ȱ",
|
||||
"}Ă岜"
|
||||
"癗E]Ņʘʟ車s"
|
||||
],
|
||||
"form": {
|
||||
"旸Ť/Õ薝隧;綡,鼞纂=": [
|
||||
"[滮]憀",
|
||||
"3\u003eÙœ蓄UK嗤眇疟Țƒ1v¸KĶ"
|
||||
"蹬器ķ8ŷ萒寎廭#疶昄Ą-Ƃƞ轵": [
|
||||
"熞ĝƌĆ1ȇyǴ濎=Tʉȼʁŀ\u003c",
|
||||
"耡q戨稞R÷mȵg釽[ƞ@",
|
||||
"đ[嬧鱒Ȁ彆媚杨嶒ĤGÀ吧Lŷ"
|
||||
],
|
||||
"餟": [
|
||||
"蒍z\u0026(K鵢Kj ŏ9Q韉Ķ%",
|
||||
"輫ǘ(¨Ƞ亱6ě#嫀^xz ",
|
||||
"@耢ɝ^¡!犃ĹĐJí¿ō擫"
|
||||
]
|
||||
},
|
||||
"session": {
|
||||
"Claims": {
|
||||
"JTI": "};Ų斻遟a衪荖舃",
|
||||
"Issuer": "芠顋敀拲h蝺$!",
|
||||
"Subject": "}j%(=ſ氆]垲莲顇",
|
||||
"JTI": "懫砰¿C筽娴ƓaPu镈賆ŗɰ",
|
||||
"Issuer": "皶竇瞍涘¹焕iǢǽɽĺŧ",
|
||||
"Subject": "矠M6ɡǜg炾ʙ$%o6肿Ȫ",
|
||||
"Audience": [
|
||||
"彑V\\廳蟕Țǡ蔯ʠ浵Ī龉磈螖畭5",
|
||||
"渇Ȯʕc"
|
||||
"ƌÙ鯆GQơ鮫R嫁ɍUƞ9+u!Ȱ踾$"
|
||||
],
|
||||
"Nonce": "Ǖ=rlƆ褡{ǏS",
|
||||
"ExpiresAt": "1975-11-17T14:21:34.205609651Z",
|
||||
"IssuedAt": "2104-07-03T15:40:03.66710966Z",
|
||||
"RequestedAt": "2031-05-18T05:14:19.449350555Z",
|
||||
"AuthTime": "2018-01-27T07:55:06.056862114Z",
|
||||
"AccessTokenHash": "鹰肁躧",
|
||||
"AuthenticationContextClassReference": "}Ɇ",
|
||||
"AuthenticationMethodsReference": "DQh:uȣ",
|
||||
"CodeHash": "ɘȏıȒ諃龟",
|
||||
"Nonce": "us旸Ť/Õ薝隧;綡,鼞",
|
||||
"ExpiresAt": "2065-11-30T13:47:03.613000626Z",
|
||||
"IssuedAt": "1976-02-22T09:57:20.479850437Z",
|
||||
"RequestedAt": "2016-04-13T04:18:53.648949323Z",
|
||||
"AuthTime": "2098-07-12T04:38:54.034043015Z",
|
||||
"AccessTokenHash": "滮]",
|
||||
"AuthenticationContextClassReference": "°3\u003eÙ",
|
||||
"AuthenticationMethodsReference": "k?µ鱔ǤÂ",
|
||||
"CodeHash": "Țƒ1v¸KĶ跭};",
|
||||
"Extra": {
|
||||
"a": {
|
||||
"^i臏f恡ƨ彮": {
|
||||
"DĘ敨ýÏʥZq7烱藌\\": null,
|
||||
"V": {
|
||||
"őŧQĝ微'X焌襱ǭɕņ殥!_n": false
|
||||
"=ſ氆": {
|
||||
"Ƿī,廖ʡ彑V\\廳蟕Ț": [
|
||||
843216989
|
||||
],
|
||||
"蔯ʠ浵Ī": {
|
||||
"H\"nǕ=rlƆ褡{ǏSȳŅ": {
|
||||
"Žg": false
|
||||
},
|
||||
"枱鰧ɛ鸁A渇": null
|
||||
}
|
||||
},
|
||||
"Ż猁": [
|
||||
1706822246
|
||||
]
|
||||
},
|
||||
"Ò椪)ɫqň2搞Ŀ高摠鲒鿮禗O": 1233332227
|
||||
"斻遟a衪荖舃9闄岈锘肺ńʥƕU}j%": 2520197933
|
||||
}
|
||||
},
|
||||
"Headers": {
|
||||
"Extra": {
|
||||
"?戋璖$9\u0026": {
|
||||
"µcɕ餦ÑEǰ哤癨浦浏1R": [
|
||||
3761201123
|
||||
],
|
||||
"頓ć§蚲6rǦ\u003cqċ": {
|
||||
"Łʀ§ȏœɽDz斡冭ȸěaʜD捛?½ʀ+": null,
|
||||
"ɒúIJ誠ƉyÖ.峷1藍殙菥趏": {
|
||||
"jHȬȆ#)\u003cX": true
|
||||
}
|
||||
"熒ɘȏıȒ諃龟ŴŠ'耐Ƭ扵ƹ玄ɕwL": {
|
||||
"ýÏʥZq7烱藌\\捀¿őŧQ": {
|
||||
"微'X焌襱ǭɕņ殥!_": null,
|
||||
"荇届UȚ?戋璖$9\u00269舋": {
|
||||
"ɕ餦ÑEǰ哤癨浦浏1Rk頓ć§蚲6": true
|
||||
}
|
||||
},
|
||||
"U": 1354158262
|
||||
"鲒鿮禗O暒aJP鐜?ĮV嫎h譭ȉ]DĘ": [
|
||||
954647573
|
||||
]
|
||||
},
|
||||
"皩Ƭ}Ɇ.雬Ɨ´唁": 1572524915
|
||||
}
|
||||
},
|
||||
"ExpiresAt": {
|
||||
"\"嘬ȹĹaó剺撱Ȱ": "1985-09-09T04:35:40.533197189Z",
|
||||
"ʆ\u003e": "1998-08-07T05:37:11.759718906Z",
|
||||
"柏ʒ鴙*鸆偡Ȓ肯Ûx": "2036-12-19T06:36:14.414805124Z"
|
||||
"\u003cqċ譈8ŪɎP绿MÅ": "2031-10-18T22:07:34.950803105Z",
|
||||
"ȸěaʜD捛?½ʀ+Ċ偢镳ʬÍɷȓ\u003c": "2049-05-13T15:27:20.968432454Z"
|
||||
},
|
||||
"Username": "qmʎaðƠ绗ʢ緦Hū",
|
||||
"Subject": "屾Ê窢ɋ鄊qɠ谫ǯǵƕ牀1鞊\\ȹ)"
|
||||
"Username": "1藍殙菥趏酱Nʎ\u0026^横懋ƶ峦Fïȫƅw",
|
||||
"Subject": "檾ĩĆ爨4犹|v炩f柏ʒ鴙*鸆偡"
|
||||
},
|
||||
"requestedAudience": [
|
||||
"鉍商OɄƣ圔,xĪɏV鵅砍"
|
||||
"肯Ûx穞Ƀ",
|
||||
"ź蕴3ǐ薝Ƅ腲=ʐ诂鱰屾Ê窢ɋ鄊qɠ谫"
|
||||
],
|
||||
"grantedAudience": [
|
||||
"C笜嚯\u003cǐšɚĀĥʋ6鉅\\þc涎漄Ɨ腼"
|
||||
"ǵƕ牀1鞊\\ȹ)}鉍商OɄƣ圔,xĪ",
|
||||
"悾xn冏裻摼0Ʈ蚵Ȼ塕»£#稏扟X"
|
||||
]
|
||||
},
|
||||
"version": "1"
|
@ -0,0 +1,401 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authorizationcode
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
apierrors "k8s.io/apimachinery/pkg/api/errors"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
|
||||
fuzz "github.com/google/gofuzz"
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
kubetesting "k8s.io/client-go/testing"
|
||||
|
||||
"go.pinniped.dev/internal/fositestorage"
|
||||
)
|
||||
|
||||
const namespace = "test-ns"
|
||||
|
||||
func TestAuthorizationCodeStorage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
secretsGVR := schema.GroupVersionResource{
|
||||
Group: "",
|
||||
Version: "v1",
|
||||
Resource: "secrets",
|
||||
}
|
||||
|
||||
wantActions := []kubetesting.Action{
|
||||
kubetesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authcode",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"active":true,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authcode",
|
||||
}),
|
||||
kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"),
|
||||
kubetesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-authcode-pwu5zs7lekbhnln2w4"),
|
||||
kubetesting.NewUpdateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authcode",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"active":false,"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authcode",
|
||||
}),
|
||||
}
|
||||
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
ID: "abcd-1",
|
||||
RequestedAt: time.Time{},
|
||||
Client: &fosite.DefaultOpenIDConnectClient{
|
||||
DefaultClient: &fosite.DefaultClient{
|
||||
ID: "pinny",
|
||||
Secret: nil,
|
||||
RedirectURIs: nil,
|
||||
GrantTypes: nil,
|
||||
ResponseTypes: nil,
|
||||
Scopes: nil,
|
||||
Audience: nil,
|
||||
Public: true,
|
||||
},
|
||||
JSONWebKeysURI: "where",
|
||||
JSONWebKeys: nil,
|
||||
TokenEndpointAuthMethod: "something",
|
||||
RequestURIs: nil,
|
||||
RequestObjectSigningAlgorithm: "",
|
||||
TokenEndpointAuthSigningAlgorithm: "",
|
||||
},
|
||||
RequestedScope: nil,
|
||||
GrantedScope: nil,
|
||||
Form: url.Values{"key": []string{"val"}},
|
||||
Session: &openid.DefaultSession{
|
||||
Claims: nil,
|
||||
Headers: nil,
|
||||
ExpiresAt: nil,
|
||||
Username: "snorlax",
|
||||
Subject: "panda",
|
||||
},
|
||||
RequestedAudience: nil,
|
||||
GrantedAudience: nil,
|
||||
}
|
||||
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
|
||||
require.NoError(t, err)
|
||||
|
||||
newRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, request, newRequest)
|
||||
|
||||
err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, wantActions, client.Actions())
|
||||
|
||||
// Doing a Get on an invalidated session should still return the session, but also return an error.
|
||||
invalidatedRequest, err := storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
|
||||
require.EqualError(t, err, "authorization code session for fancy-signature has already been used: Authorization code has ben invalidated")
|
||||
require.Equal(t, "abcd-1", invalidatedRequest.GetID())
|
||||
}
|
||||
|
||||
func TestGetNotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
_, notFoundErr := storage.GetAuthorizeCodeSession(ctx, "non-existent-signature", nil)
|
||||
require.EqualError(t, notFoundErr, "not_found")
|
||||
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestInvalidateWhenNotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
notFoundErr := storage.InvalidateAuthorizeCodeSession(ctx, "non-existent-signature")
|
||||
require.EqualError(t, notFoundErr, "not_found")
|
||||
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestInvalidateWhenConflictOnUpdateHappens(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
client.PrependReactor("update", "secrets", func(_ kubetesting.Action) (bool, runtime.Object, error) {
|
||||
return true, nil, apierrors.NewConflict(schema.GroupResource{
|
||||
Group: "",
|
||||
Resource: "secrets",
|
||||
}, "some-secret-name", fmt.Errorf("there was a conflict"))
|
||||
})
|
||||
|
||||
request := &fosite.Request{
|
||||
ID: "some-request-id",
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
Session: &openid.DefaultSession{},
|
||||
}
|
||||
err := storage.CreateAuthorizeCodeSession(ctx, "fancy-signature", request)
|
||||
require.NoError(t, err)
|
||||
err = storage.InvalidateAuthorizeCodeSession(ctx, "fancy-signature")
|
||||
require.EqualError(t, err, `The request could not be completed due to concurrent access: failed to update authcode for signature fancy-signature at resource version : Operation cannot be fulfilled on secrets "some-secret-name": there was a conflict`)
|
||||
}
|
||||
|
||||
func TestWrongVersion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authcode",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version", "active": true}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authcode",
|
||||
}
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
|
||||
|
||||
require.EqualError(t, err, "authorization request data has wrong version: authorization code session for fancy-signature has version not-the-right-version instead of 1")
|
||||
}
|
||||
|
||||
func TestNilSessionRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-authcode-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "authcode",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value", "version":"1", "active": true}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/authcode",
|
||||
}
|
||||
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetAuthorizeCodeSession(ctx, "fancy-signature", nil)
|
||||
require.EqualError(t, err, "malformed authorization code session for fancy-signature: authorization request data must be present")
|
||||
}
|
||||
|
||||
func TestCreateWithNilRequester(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", nil)
|
||||
require.EqualError(t, err, "requester must be of type fosite.Request")
|
||||
}
|
||||
|
||||
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
Session: nil,
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
}
|
||||
err := storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
|
||||
|
||||
request = &fosite.Request{
|
||||
Session: &openid.DefaultSession{},
|
||||
Client: nil,
|
||||
}
|
||||
err = storage.CreateAuthorizeCodeSession(ctx, "signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
|
||||
}
|
||||
|
||||
// TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession asserts that we can correctly round trip our authorize code session.
|
||||
// It will detect any changes to fosite.AuthorizeRequest and guarantees that all interface types have concrete implementations.
|
||||
func TestFuzzAndJSONNewValidEmptyAuthorizeCodeSession(t *testing.T) {
|
||||
validSession := NewValidEmptyAuthorizeCodeSession()
|
||||
|
||||
// sanity check our valid session
|
||||
extractedRequest, err := fositestorage.ValidateAndExtractAuthorizeRequest(validSession.Request)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, validSession.Request, extractedRequest)
|
||||
|
||||
// checked above
|
||||
defaultClient := validSession.Request.Client.(*fosite.DefaultOpenIDConnectClient)
|
||||
defaultSession := validSession.Request.Session.(*openid.DefaultSession)
|
||||
|
||||
// makes it easier to use a raw string
|
||||
replacer := strings.NewReplacer("`", "a")
|
||||
randString := func(c fuzz.Continue) string {
|
||||
for {
|
||||
s := c.RandString()
|
||||
if len(s) == 0 {
|
||||
continue // skip empty string
|
||||
}
|
||||
return replacer.Replace(s)
|
||||
}
|
||||
}
|
||||
|
||||
// deterministic fuzzing of fosite.Request
|
||||
f := fuzz.New().RandSource(rand.NewSource(1)).NilChance(0).NumElements(1, 3).Funcs(
|
||||
// these functions guarantee that these are the only interface types we need to fill out
|
||||
// if fosite.Request changes to add more, the fuzzer will panic
|
||||
func(fc *fosite.Client, c fuzz.Continue) {
|
||||
c.Fuzz(defaultClient)
|
||||
*fc = defaultClient
|
||||
},
|
||||
func(fs *fosite.Session, c fuzz.Continue) {
|
||||
c.Fuzz(defaultSession)
|
||||
*fs = defaultSession
|
||||
},
|
||||
|
||||
// these types contain an interface{} that we need to handle
|
||||
// this is safe because we explicitly provide the openid.DefaultSession concrete type
|
||||
func(value *map[string]interface{}, c fuzz.Continue) {
|
||||
// cover all the JSON data types just in case
|
||||
*value = map[string]interface{}{
|
||||
randString(c): float64(c.Intn(1 << 32)),
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): []interface{}{float64(c.Intn(1 << 32))},
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): nil,
|
||||
randString(c): map[string]interface{}{
|
||||
randString(c): c.RandBool(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
// JWK contains an interface{} Key that we need to handle
|
||||
// this is safe because JWK explicitly implements JSON marshalling and unmarshalling
|
||||
func(jwk *jose.JSONWebKey, c fuzz.Continue) {
|
||||
key, _, err := ed25519.GenerateKey(c)
|
||||
require.NoError(t, err)
|
||||
jwk.Key = key
|
||||
|
||||
// set these fields to make the .Equal comparison work
|
||||
jwk.Certificates = []*x509.Certificate{}
|
||||
jwk.CertificatesURL = &url.URL{}
|
||||
jwk.CertificateThumbprintSHA1 = []byte{}
|
||||
jwk.CertificateThumbprintSHA256 = []byte{}
|
||||
},
|
||||
|
||||
// set this to make the .Equal comparison work
|
||||
// this is safe because Time explicitly implements JSON marshalling and unmarshalling
|
||||
func(tp *time.Time, c fuzz.Continue) {
|
||||
*tp = time.Unix(c.Int63n(1<<32), c.Int63n(1<<32)).UTC()
|
||||
},
|
||||
|
||||
// make random strings that do not contain any ` characters
|
||||
func(s *string, c fuzz.Continue) {
|
||||
*s = randString(c)
|
||||
},
|
||||
// handle string type alias
|
||||
func(s *fosite.TokenType, c fuzz.Continue) {
|
||||
*s = fosite.TokenType(randString(c))
|
||||
},
|
||||
// handle string type alias
|
||||
func(s *fosite.Arguments, c fuzz.Continue) {
|
||||
n := c.Intn(3) + 1 // 1 to 3 items
|
||||
arguments := make(fosite.Arguments, n)
|
||||
for i := range arguments {
|
||||
arguments[i] = randString(c)
|
||||
}
|
||||
*s = arguments
|
||||
},
|
||||
)
|
||||
|
||||
f.Fuzz(validSession)
|
||||
|
||||
const name = "fuzz" // value is irrelevant
|
||||
ctx := context.Background()
|
||||
secrets := fake.NewSimpleClientset().CoreV1().Secrets(name)
|
||||
storage := New(secrets)
|
||||
|
||||
// issue a create using the fuzzed request to confirm that marshalling works
|
||||
err = storage.CreateAuthorizeCodeSession(ctx, name, validSession.Request)
|
||||
require.NoError(t, err)
|
||||
|
||||
// retrieve a copy of the fuzzed request from storage to confirm that unmarshalling works
|
||||
newRequest, err := storage.GetAuthorizeCodeSession(ctx, name, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// the fuzzed request and the copy from storage should be exactly the same
|
||||
require.Equal(t, validSession.Request, newRequest)
|
||||
|
||||
secretList, err := secrets.List(ctx, metav1.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, secretList.Items, 1)
|
||||
authorizeCodeSessionJSONFromStorage := string(secretList.Items[0].Data["pinniped-storage-data"])
|
||||
|
||||
// set these to match CreateAuthorizeCodeSession so that .JSONEq works
|
||||
validSession.Active = true
|
||||
validSession.Version = "1"
|
||||
|
||||
validSessionJSONBytes, err := json.MarshalIndent(validSession, "", "\t")
|
||||
require.NoError(t, err)
|
||||
authorizeCodeSessionJSONFromFuzzing := string(validSessionJSONBytes)
|
||||
|
||||
// the fuzzed session and storage session should have identical JSON
|
||||
require.JSONEq(t, authorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromStorage)
|
||||
|
||||
// while the fuzzer will panic if AuthorizeRequest changes in a way that cannot be fuzzed,
|
||||
// if it adds a new field that can be fuzzed, this check will fail
|
||||
// thus if AuthorizeRequest changes, we will detect it here (though we could possibly miss an omitempty field)
|
||||
require.Equal(t, ExpectedAuthorizeCodeSessionJSONFromFuzzing, authorizeCodeSessionJSONFromFuzzing)
|
||||
}
|
34
internal/fositestorage/fositestorage.go
Normal file
34
internal/fositestorage/fositestorage.go
Normal file
@ -0,0 +1,34 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package fositestorage
|
||||
|
||||
import (
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
|
||||
"go.pinniped.dev/internal/constable"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrInvalidRequestType = constable.Error("requester must be of type fosite.Request")
|
||||
ErrInvalidClientType = constable.Error("requester's client must be of type fosite.DefaultOpenIDConnectClient")
|
||||
ErrInvalidSessionType = constable.Error("requester's session must be of type openid.DefaultSession")
|
||||
)
|
||||
|
||||
func ValidateAndExtractAuthorizeRequest(requester fosite.Requester) (*fosite.Request, error) {
|
||||
request, ok1 := requester.(*fosite.Request)
|
||||
if !ok1 {
|
||||
return nil, ErrInvalidRequestType
|
||||
}
|
||||
_, ok2 := request.Client.(*fosite.DefaultOpenIDConnectClient)
|
||||
if !ok2 {
|
||||
return nil, ErrInvalidClientType
|
||||
}
|
||||
_, ok3 := request.Session.(*openid.DefaultSession)
|
||||
if !ok3 {
|
||||
return nil, ErrInvalidSessionType
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
124
internal/fositestorage/openidconnect/openidconnect.go
Normal file
124
internal/fositestorage/openidconnect/openidconnect.go
Normal file
@ -0,0 +1,124 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"k8s.io/apimachinery/pkg/api/errors"
|
||||
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
|
||||
"go.pinniped.dev/internal/constable"
|
||||
"go.pinniped.dev/internal/crud"
|
||||
"go.pinniped.dev/internal/fositestorage"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrInvalidOIDCRequestVersion = constable.Error("oidc request data has wrong version")
|
||||
ErrInvalidOIDCRequestData = constable.Error("oidc request data must be present")
|
||||
ErrMalformedAuthorizationCode = constable.Error("malformed authorization code")
|
||||
|
||||
oidcStorageVersion = "1"
|
||||
)
|
||||
|
||||
var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{}
|
||||
|
||||
type openIDConnectRequestStorage struct {
|
||||
storage crud.Storage
|
||||
}
|
||||
|
||||
type session struct {
|
||||
Request *fosite.Request `json:"request"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func New(secrets corev1client.SecretInterface) openid.OpenIDConnectRequestStorage {
|
||||
return &openIDConnectRequestStorage{storage: crud.New("oidc", secrets)}
|
||||
}
|
||||
|
||||
func (a *openIDConnectRequestStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error {
|
||||
signature, err := getSignature(authcode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = a.storage.Create(ctx, signature, &session{Request: request, Version: oidcStorageVersion})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *openIDConnectRequestStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, _ fosite.Requester) (fosite.Requester, error) {
|
||||
signature, err := getSignature(authcode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, _, err := a.getSession(ctx, signature)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.Request, err
|
||||
}
|
||||
|
||||
func (a *openIDConnectRequestStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error {
|
||||
signature, err := getSignature(authcode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return a.storage.Delete(ctx, signature)
|
||||
}
|
||||
|
||||
func (a *openIDConnectRequestStorage) getSession(ctx context.Context, signature string) (*session, string, error) {
|
||||
session := newValidEmptyOIDCSession()
|
||||
rv, err := a.storage.Get(ctx, signature, session)
|
||||
|
||||
if errors.IsNotFound(err) {
|
||||
return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to get oidc session for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
if version := session.Version; version != oidcStorageVersion {
|
||||
return nil, "", fmt.Errorf("%w: oidc session for %s has version %s instead of %s",
|
||||
ErrInvalidOIDCRequestVersion, signature, version, oidcStorageVersion)
|
||||
}
|
||||
|
||||
if session.Request.ID == "" {
|
||||
return nil, "", fmt.Errorf("malformed oidc session for %s: %w", signature, ErrInvalidOIDCRequestData)
|
||||
}
|
||||
|
||||
return session, rv, nil
|
||||
}
|
||||
|
||||
func newValidEmptyOIDCSession() *session {
|
||||
return &session{
|
||||
Request: &fosite.Request{
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
Session: &openid.DefaultSession{},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getSignature(authorizationCode string) (string, error) {
|
||||
split := strings.Split(authorizationCode, ".")
|
||||
|
||||
if len(split) != 2 {
|
||||
return "", ErrMalformedAuthorizationCode
|
||||
}
|
||||
|
||||
return split[1], nil
|
||||
}
|
209
internal/fositestorage/openidconnect/openidconnect_test.go
Normal file
209
internal/fositestorage/openidconnect/openidconnect_test.go
Normal file
@ -0,0 +1,209 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
coretesting "k8s.io/client-go/testing"
|
||||
)
|
||||
|
||||
const namespace = "test-ns"
|
||||
|
||||
func TestOpenIdConnectStorage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
secretsGVR := schema.GroupVersionResource{
|
||||
Group: "",
|
||||
Version: "v1",
|
||||
Resource: "secrets",
|
||||
}
|
||||
|
||||
wantActions := []coretesting.Action{
|
||||
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "oidc",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/oidc",
|
||||
}),
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"),
|
||||
coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-oidc-pwu5zs7lekbhnln2w4"),
|
||||
}
|
||||
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
ID: "abcd-1",
|
||||
RequestedAt: time.Time{},
|
||||
Client: &fosite.DefaultOpenIDConnectClient{
|
||||
DefaultClient: &fosite.DefaultClient{
|
||||
ID: "pinny",
|
||||
Secret: nil,
|
||||
RedirectURIs: nil,
|
||||
GrantTypes: nil,
|
||||
ResponseTypes: nil,
|
||||
Scopes: nil,
|
||||
Audience: nil,
|
||||
Public: true,
|
||||
},
|
||||
JSONWebKeysURI: "where",
|
||||
JSONWebKeys: nil,
|
||||
TokenEndpointAuthMethod: "something",
|
||||
RequestURIs: nil,
|
||||
RequestObjectSigningAlgorithm: "",
|
||||
TokenEndpointAuthSigningAlgorithm: "",
|
||||
},
|
||||
RequestedScope: nil,
|
||||
GrantedScope: nil,
|
||||
Form: url.Values{"key": []string{"val"}},
|
||||
Session: &openid.DefaultSession{
|
||||
Claims: nil,
|
||||
Headers: nil,
|
||||
ExpiresAt: nil,
|
||||
Username: "snorlax",
|
||||
Subject: "panda",
|
||||
},
|
||||
RequestedAudience: nil,
|
||||
GrantedAudience: nil,
|
||||
}
|
||||
err := storage.CreateOpenIDConnectSession(ctx, "fancy-code.fancy-signature", request)
|
||||
require.NoError(t, err)
|
||||
|
||||
newRequest, err := storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, request, newRequest)
|
||||
|
||||
err = storage.DeleteOpenIDConnectSession(ctx, "fancy-code.fancy-signature")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, wantActions, client.Actions())
|
||||
}
|
||||
|
||||
func TestGetNotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
_, notFoundErr := storage.GetOpenIDConnectSession(ctx, "authcode.non-existent-signature", nil)
|
||||
require.EqualError(t, notFoundErr, "not_found")
|
||||
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestWrongVersion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "oidc",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/oidc",
|
||||
}
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
|
||||
|
||||
require.EqualError(t, err, "oidc request data has wrong version: oidc session for fancy-signature has version not-the-right-version instead of 1")
|
||||
}
|
||||
|
||||
func TestNilSessionRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-oidc-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "oidc",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/oidc",
|
||||
}
|
||||
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetOpenIDConnectSession(ctx, "fancy-code.fancy-signature", nil)
|
||||
require.EqualError(t, err, "malformed oidc session for fancy-signature: oidc request data must be present")
|
||||
}
|
||||
|
||||
func TestCreateWithNilRequester(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", nil)
|
||||
require.EqualError(t, err, "requester must be of type fosite.Request")
|
||||
}
|
||||
|
||||
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
Session: nil,
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
}
|
||||
err := storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
|
||||
|
||||
request = &fosite.Request{
|
||||
Session: &openid.DefaultSession{},
|
||||
Client: nil,
|
||||
}
|
||||
err = storage.CreateOpenIDConnectSession(ctx, "authcode.signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
|
||||
}
|
||||
|
||||
func TestAuthcodeHasNoDot(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
err := storage.CreateOpenIDConnectSession(ctx, "all-one-part", nil)
|
||||
require.EqualError(t, err, "malformed authorization code")
|
||||
}
|
98
internal/fositestorage/pkce/pkce.go
Normal file
98
internal/fositestorage/pkce/pkce.go
Normal file
@ -0,0 +1,98 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package pkce
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/ory/fosite/handler/pkce"
|
||||
"k8s.io/apimachinery/pkg/api/errors"
|
||||
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
|
||||
"go.pinniped.dev/internal/constable"
|
||||
"go.pinniped.dev/internal/crud"
|
||||
"go.pinniped.dev/internal/fositestorage"
|
||||
)
|
||||
|
||||
const (
|
||||
ErrInvalidPKCERequestVersion = constable.Error("pkce request data has wrong version")
|
||||
ErrInvalidPKCERequestData = constable.Error("pkce request data must be present")
|
||||
|
||||
pkceStorageVersion = "1"
|
||||
)
|
||||
|
||||
var _ pkce.PKCERequestStorage = &pkceStorage{}
|
||||
|
||||
type pkceStorage struct {
|
||||
storage crud.Storage
|
||||
}
|
||||
|
||||
type session struct {
|
||||
Request *fosite.Request `json:"request"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
func New(secrets corev1client.SecretInterface) pkce.PKCERequestStorage {
|
||||
return &pkceStorage{storage: crud.New("pkce", secrets)}
|
||||
}
|
||||
|
||||
func (a *pkceStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error {
|
||||
request, err := fositestorage.ValidateAndExtractAuthorizeRequest(requester)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = a.storage.Create(ctx, signature, &session{Request: request, Version: pkceStorageVersion})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *pkceStorage) GetPKCERequestSession(ctx context.Context, signature string, _ fosite.Session) (fosite.Requester, error) {
|
||||
session, _, err := a.getSession(ctx, signature)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session.Request, err
|
||||
}
|
||||
|
||||
func (a *pkceStorage) DeletePKCERequestSession(ctx context.Context, signature string) error {
|
||||
return a.storage.Delete(ctx, signature)
|
||||
}
|
||||
|
||||
func (a *pkceStorage) getSession(ctx context.Context, signature string) (*session, string, error) {
|
||||
session := newValidEmptyPKCESession()
|
||||
rv, err := a.storage.Get(ctx, signature, session)
|
||||
|
||||
if errors.IsNotFound(err) {
|
||||
return nil, "", fosite.ErrNotFound.WithCause(err).WithDebug(err.Error())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to get pkce session for %s: %w", signature, err)
|
||||
}
|
||||
|
||||
if version := session.Version; version != pkceStorageVersion {
|
||||
return nil, "", fmt.Errorf("%w: pkce session for %s has version %s instead of %s",
|
||||
ErrInvalidPKCERequestVersion, signature, version, pkceStorageVersion)
|
||||
}
|
||||
|
||||
if session.Request.ID == "" {
|
||||
return nil, "", fmt.Errorf("malformed pkce session for %s: %w", signature, ErrInvalidPKCERequestData)
|
||||
}
|
||||
|
||||
return session, rv, nil
|
||||
}
|
||||
|
||||
func newValidEmptyPKCESession() *session {
|
||||
return &session{
|
||||
Request: &fosite.Request{
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
Session: &openid.DefaultSession{},
|
||||
},
|
||||
}
|
||||
}
|
199
internal/fositestorage/pkce/pkce_test.go
Normal file
199
internal/fositestorage/pkce/pkce_test.go
Normal file
@ -0,0 +1,199 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package pkce
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
coretesting "k8s.io/client-go/testing"
|
||||
)
|
||||
|
||||
const namespace = "test-ns"
|
||||
|
||||
func TestPKCEStorage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
secretsGVR := schema.GroupVersionResource{
|
||||
Group: "",
|
||||
Version: "v1",
|
||||
Resource: "secrets",
|
||||
}
|
||||
|
||||
wantActions := []coretesting.Action{
|
||||
coretesting.NewCreateAction(secretsGVR, namespace, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "pkce",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/pkce",
|
||||
}),
|
||||
coretesting.NewGetAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"),
|
||||
coretesting.NewDeleteAction(secretsGVR, namespace, "pinniped-storage-pkce-pwu5zs7lekbhnln2w4"),
|
||||
}
|
||||
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
ID: "abcd-1",
|
||||
RequestedAt: time.Time{},
|
||||
Client: &fosite.DefaultOpenIDConnectClient{
|
||||
DefaultClient: &fosite.DefaultClient{
|
||||
ID: "pinny",
|
||||
Secret: nil,
|
||||
RedirectURIs: nil,
|
||||
GrantTypes: nil,
|
||||
ResponseTypes: nil,
|
||||
Scopes: nil,
|
||||
Audience: nil,
|
||||
Public: true,
|
||||
},
|
||||
JSONWebKeysURI: "where",
|
||||
JSONWebKeys: nil,
|
||||
TokenEndpointAuthMethod: "something",
|
||||
RequestURIs: nil,
|
||||
RequestObjectSigningAlgorithm: "",
|
||||
TokenEndpointAuthSigningAlgorithm: "",
|
||||
},
|
||||
RequestedScope: nil,
|
||||
GrantedScope: nil,
|
||||
Form: url.Values{"key": []string{"val"}},
|
||||
Session: &openid.DefaultSession{
|
||||
Claims: nil,
|
||||
Headers: nil,
|
||||
ExpiresAt: nil,
|
||||
Username: "snorlax",
|
||||
Subject: "panda",
|
||||
},
|
||||
RequestedAudience: nil,
|
||||
GrantedAudience: nil,
|
||||
}
|
||||
err := storage.CreatePKCERequestSession(ctx, "fancy-signature", request)
|
||||
require.NoError(t, err)
|
||||
|
||||
newRequest, err := storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, request, newRequest)
|
||||
|
||||
err = storage.DeletePKCERequestSession(ctx, "fancy-signature")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, wantActions, client.Actions())
|
||||
}
|
||||
|
||||
func TestGetNotFound(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
_, notFoundErr := storage.GetPKCERequestSession(ctx, "non-existent-signature", nil)
|
||||
require.EqualError(t, notFoundErr, "not_found")
|
||||
require.True(t, errors.Is(notFoundErr, fosite.ErrNotFound))
|
||||
}
|
||||
|
||||
func TestWrongVersion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "pkce",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"request":{"id":"abcd-1","requestedAt":"0001-01-01T00:00:00Z","client":{"id":"pinny","redirect_uris":null,"grant_types":null,"response_types":null,"scopes":null,"audience":null,"public":true,"jwks_uri":"where","jwks":null,"token_endpoint_auth_method":"something","request_uris":null,"request_object_signing_alg":"","token_endpoint_auth_signing_alg":""},"scopes":null,"grantedScopes":null,"form":{"key":["val"]},"session":{"Claims":null,"Headers":null,"ExpiresAt":null,"Username":"snorlax","Subject":"panda"},"requestedAudience":null,"grantedAudience":null},"version":"not-the-right-version"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/pkce",
|
||||
}
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
|
||||
|
||||
require.EqualError(t, err, "pkce request data has wrong version: pkce session for fancy-signature has version not-the-right-version instead of 1")
|
||||
}
|
||||
|
||||
func TestNilSessionRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
secret := &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Name: "pinniped-storage-pkce-pwu5zs7lekbhnln2w4",
|
||||
ResourceVersion: "",
|
||||
Labels: map[string]string{
|
||||
"storage.pinniped.dev": "pkce",
|
||||
},
|
||||
},
|
||||
Data: map[string][]byte{
|
||||
"pinniped-storage-data": []byte(`{"nonsense-key": "nonsense-value","version":"1"}`),
|
||||
"pinniped-storage-version": []byte("1"),
|
||||
},
|
||||
Type: "storage.pinniped.dev/pkce",
|
||||
}
|
||||
|
||||
_, err := secrets.Create(ctx, secret, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = storage.GetPKCERequestSession(ctx, "fancy-signature", nil)
|
||||
require.EqualError(t, err, "malformed pkce session for fancy-signature: pkce request data must be present")
|
||||
}
|
||||
|
||||
func TestCreateWithNilRequester(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", nil)
|
||||
require.EqualError(t, err, "requester must be of type fosite.Request")
|
||||
}
|
||||
|
||||
func TestCreateWithWrongRequesterDataTypes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets(namespace)
|
||||
storage := New(secrets)
|
||||
|
||||
request := &fosite.Request{
|
||||
Session: nil,
|
||||
Client: &fosite.DefaultOpenIDConnectClient{},
|
||||
}
|
||||
err := storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's session must be of type openid.DefaultSession")
|
||||
|
||||
request = &fosite.Request{
|
||||
Session: &openid.DefaultSession{},
|
||||
Client: nil,
|
||||
}
|
||||
err = storage.CreatePKCERequestSession(ctx, "signature-doesnt-matter", request)
|
||||
require.EqualError(t, err, "requester's client must be of type fosite.DefaultOpenIDConnectClient")
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mockupstreamoidcidentityprovider
|
||||
|
||||
//go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider UpstreamOIDCIdentityProviderI
|
@ -0,0 +1,159 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: go.pinniped.dev/internal/oidc/provider (interfaces: UpstreamOIDCIdentityProviderI)
|
||||
|
||||
// Package mockupstreamoidcidentityprovider is a generated GoMock package.
|
||||
package mockupstreamoidcidentityprovider
|
||||
|
||||
import (
|
||||
context "context"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
nonce "go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
pkce "go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
oauth2 "golang.org/x/oauth2"
|
||||
url "net/url"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockUpstreamOIDCIdentityProviderI is a mock of UpstreamOIDCIdentityProviderI interface
|
||||
type MockUpstreamOIDCIdentityProviderI struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockUpstreamOIDCIdentityProviderIMockRecorder
|
||||
}
|
||||
|
||||
// MockUpstreamOIDCIdentityProviderIMockRecorder is the mock recorder for MockUpstreamOIDCIdentityProviderI
|
||||
type MockUpstreamOIDCIdentityProviderIMockRecorder struct {
|
||||
mock *MockUpstreamOIDCIdentityProviderI
|
||||
}
|
||||
|
||||
// NewMockUpstreamOIDCIdentityProviderI creates a new mock instance
|
||||
func NewMockUpstreamOIDCIdentityProviderI(ctrl *gomock.Controller) *MockUpstreamOIDCIdentityProviderI {
|
||||
mock := &MockUpstreamOIDCIdentityProviderI{ctrl: ctrl}
|
||||
mock.recorder = &MockUpstreamOIDCIdentityProviderIMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityProviderIMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ExchangeAuthcodeAndValidateTokens mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4)
|
||||
ret0, _ := ret[0].(oidctypes.Token)
|
||||
ret1, _ := ret[1].(map[string]interface{})
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4)
|
||||
}
|
||||
|
||||
// GetAuthorizationURL mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetAuthorizationURL() *url.URL {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthorizationURL")
|
||||
ret0, _ := ret[0].(*url.URL)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetAuthorizationURL indicates an expected call of GetAuthorizationURL
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetAuthorizationURL() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthorizationURL", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetAuthorizationURL))
|
||||
}
|
||||
|
||||
// GetClientID mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetClientID() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClientID")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClientID indicates an expected call of GetClientID
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetClientID() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientID", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetClientID))
|
||||
}
|
||||
|
||||
// GetGroupsClaim mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetGroupsClaim() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupsClaim")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetGroupsClaim indicates an expected call of GetGroupsClaim
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetGroupsClaim() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetGroupsClaim))
|
||||
}
|
||||
|
||||
// GetName mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetName() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetName")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetName indicates an expected call of GetName
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetName() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetName))
|
||||
}
|
||||
|
||||
// GetScopes mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetScopes() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetScopes")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetScopes indicates an expected call of GetScopes
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetScopes() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetScopes", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetScopes))
|
||||
}
|
||||
|
||||
// GetUsernameClaim mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) GetUsernameClaim() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetUsernameClaim")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetUsernameClaim indicates an expected call of GetUsernameClaim
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsernameClaim", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).GetUsernameClaim))
|
||||
}
|
||||
|
||||
// ValidateToken mocks base method
|
||||
func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(oidctypes.Token)
|
||||
ret1, _ := ret[1].(map[string]interface{})
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// ValidateToken indicates an expected call of ValidateToken
|
||||
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ValidateToken(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateToken", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ValidateToken), arg0, arg1, arg2)
|
||||
}
|
@ -9,14 +9,13 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/ory/fosite/token/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
@ -24,42 +23,15 @@ import (
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
const (
|
||||
// Just in case we need to make a breaking change to the format of the upstream state param,
|
||||
// we are including a format version number. This gives the opportunity for a future version of Pinniped
|
||||
// to have the consumer of this format decide to reject versions that it doesn't understand.
|
||||
upstreamStateParamFormatVersion = "1"
|
||||
|
||||
// The `name` passed to the encoder for encoding the upstream state param value. This name is short
|
||||
// because it will be encoded into the upstream state param value and we're trying to keep that small.
|
||||
upstreamStateParamEncodingName = "s"
|
||||
|
||||
// The name of the browser cookie which shall hold our CSRF value.
|
||||
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
|
||||
csrfCookieName = "__Host-pinniped-csrf"
|
||||
|
||||
// The `name` passed to the encoder for encoding and decoding the CSRF cookie contents.
|
||||
csrfCookieEncodingName = "csrf"
|
||||
)
|
||||
|
||||
type IDPListGetter interface {
|
||||
GetIDPList() []provider.UpstreamOIDCIdentityProvider
|
||||
}
|
||||
|
||||
// This is the encoding side of the securecookie.Codec interface.
|
||||
type Encoder interface {
|
||||
Encode(name string, value interface{}) (string, error)
|
||||
}
|
||||
|
||||
func NewHandler(
|
||||
issuer string,
|
||||
idpListGetter IDPListGetter,
|
||||
downstreamIssuer string,
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
generateCSRF func() (csrftoken.CSRFToken, error),
|
||||
generatePKCE func() (pkce.Code, error),
|
||||
generateNonce func() (nonce.Nonce, error),
|
||||
upstreamStateEncoder Encoder,
|
||||
cookieCodec securecookie.Codec,
|
||||
upstreamStateEncoder oidc.Encoder,
|
||||
cookieCodec oidc.Codec,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||||
@ -69,11 +41,7 @@ func NewHandler(
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method)
|
||||
}
|
||||
|
||||
csrfFromCookie, err := readCSRFCookie(r, cookieCodec)
|
||||
if err != nil {
|
||||
plog.InfoErr("error reading CSRF cookie", err)
|
||||
return err
|
||||
}
|
||||
csrfFromCookie := readCSRFCookie(r, cookieCodec)
|
||||
|
||||
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r)
|
||||
if err != nil {
|
||||
@ -116,15 +84,22 @@ func NewHandler(
|
||||
}
|
||||
|
||||
upstreamOAuthConfig := oauth2.Config{
|
||||
ClientID: upstreamIDP.ClientID,
|
||||
ClientID: upstreamIDP.GetClientID(),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: upstreamIDP.AuthorizationURL.String(),
|
||||
AuthURL: upstreamIDP.GetAuthorizationURL().String(),
|
||||
},
|
||||
RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.Name),
|
||||
Scopes: upstreamIDP.Scopes,
|
||||
RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer),
|
||||
Scopes: upstreamIDP.GetScopes(),
|
||||
}
|
||||
|
||||
encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder)
|
||||
encodedStateParamValue, err := upstreamStateParam(
|
||||
authorizeRequester,
|
||||
upstreamIDP.GetName(),
|
||||
nonceValue,
|
||||
csrfValue,
|
||||
pkceValue,
|
||||
upstreamStateEncoder,
|
||||
)
|
||||
if err != nil {
|
||||
plog.Error("authorize upstream state param error", err)
|
||||
return err
|
||||
@ -154,20 +129,23 @@ func NewHandler(
|
||||
})
|
||||
}
|
||||
|
||||
func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) {
|
||||
receivedCSRFCookie, err := r.Cookie(csrfCookieName)
|
||||
func readCSRFCookie(r *http.Request, codec oidc.Codec) csrftoken.CSRFToken {
|
||||
receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
|
||||
if err != nil {
|
||||
// Error means that the cookie was not found
|
||||
return "", nil
|
||||
return ""
|
||||
}
|
||||
|
||||
var csrfFromCookie csrftoken.CSRFToken
|
||||
err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
|
||||
err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
|
||||
if err != nil {
|
||||
return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err)
|
||||
// We can ignore any errors and just make a new cookie. Hopefully this will
|
||||
// make the user experience better if, for example, the server rotated
|
||||
// cookie signing keys and then a user submitted a very old cookie.
|
||||
return ""
|
||||
}
|
||||
|
||||
return csrfFromCookie, nil
|
||||
return csrfFromCookie
|
||||
}
|
||||
|
||||
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
|
||||
@ -178,7 +156,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
|
||||
}
|
||||
}
|
||||
|
||||
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
|
||||
func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (provider.UpstreamOIDCIdentityProviderI, error) {
|
||||
allUpstreamIDPs := idpListGetter.GetIDPList()
|
||||
if len(allUpstreamIDPs) == 0 {
|
||||
return nil, httperr.New(
|
||||
@ -191,7 +169,7 @@ func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdent
|
||||
"Too many upstream providers are configured (support for multiple upstreams is not yet implemented)",
|
||||
)
|
||||
}
|
||||
return &allUpstreamIDPs[0], nil
|
||||
return allUpstreamIDPs[0], nil
|
||||
}
|
||||
|
||||
func generateValues(
|
||||
@ -214,48 +192,42 @@ func generateValues(
|
||||
return csrfValue, nonceValue, pkceValue, nil
|
||||
}
|
||||
|
||||
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
|
||||
type upstreamStateParamData struct {
|
||||
AuthParams string `json:"p"`
|
||||
Nonce nonce.Nonce `json:"n"`
|
||||
CSRFToken csrftoken.CSRFToken `json:"c"`
|
||||
PKCECode pkce.Code `json:"k"`
|
||||
StateParamFormatVersion string `json:"v"`
|
||||
}
|
||||
|
||||
func upstreamStateParam(
|
||||
authorizeRequester fosite.AuthorizeRequester,
|
||||
upstreamName string,
|
||||
nonceValue nonce.Nonce,
|
||||
csrfValue csrftoken.CSRFToken,
|
||||
pkceValue pkce.Code,
|
||||
encoder Encoder,
|
||||
encoder oidc.Encoder,
|
||||
) (string, error) {
|
||||
stateParamData := upstreamStateParamData{
|
||||
stateParamData := oidc.UpstreamStateParamData{
|
||||
AuthParams: authorizeRequester.GetRequestForm().Encode(),
|
||||
UpstreamName: upstreamName,
|
||||
Nonce: nonceValue,
|
||||
CSRFToken: csrfValue,
|
||||
PKCECode: pkceValue,
|
||||
StateParamFormatVersion: upstreamStateParamFormatVersion,
|
||||
FormatVersion: oidc.UpstreamStateParamFormatVersion,
|
||||
}
|
||||
encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData)
|
||||
encodedStateParamValue, err := encoder.Encode(oidc.UpstreamStateParamEncodingName, stateParamData)
|
||||
if err != nil {
|
||||
return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err)
|
||||
}
|
||||
return encodedStateParamValue, nil
|
||||
}
|
||||
|
||||
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error {
|
||||
encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue)
|
||||
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec oidc.Codec) error {
|
||||
encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue)
|
||||
if err != nil {
|
||||
return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: csrfCookieName,
|
||||
Name: oidc.CSRFCookieName,
|
||||
Value: encodedCSRFValue,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
Secure: true,
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
return nil
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"go.pinniped.dev/internal/here"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/oidctestutil"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
@ -28,6 +29,7 @@ import (
|
||||
|
||||
func TestAuthorizationEndpoint(t *testing.T) {
|
||||
const (
|
||||
downstreamIssuer = "https://my-downstream-issuer.com/some-path"
|
||||
downstreamRedirectURI = "http://127.0.0.1/callback"
|
||||
downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback"
|
||||
)
|
||||
@ -113,21 +115,19 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
|
||||
require.NoError(t, err)
|
||||
|
||||
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{
|
||||
upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: "some-idp",
|
||||
ClientID: "some-client-id",
|
||||
AuthorizationURL: *upstreamAuthURL,
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
}
|
||||
|
||||
issuer := "https://my-issuer.com/some-path"
|
||||
|
||||
// Configure fosite the same way that the production code would, except use in-memory storage.
|
||||
// Configure fosite the same way that the production code would, using NullStorage to turn off storage.
|
||||
oauthStore := oidc.NullStorage{}
|
||||
hmacSecret := []byte("some secret - must have at least 32 bytes")
|
||||
var signingKeyIsUnused *ecdsa.PrivateKey
|
||||
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
||||
oauthHelper := oidc.FositeOauth2Helper(issuer, oauthStore, hmacSecret, signingKeyIsUnused)
|
||||
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, signingKeyIsUnused)
|
||||
|
||||
happyCSRF := "test-csrf"
|
||||
happyPKCE := "test-pkce"
|
||||
@ -206,14 +206,19 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides))
|
||||
}
|
||||
|
||||
expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string {
|
||||
expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride, upstreamNameOverride string) string {
|
||||
csrf := happyCSRF
|
||||
if csrfValueOverride != "" {
|
||||
csrf = csrfValueOverride
|
||||
}
|
||||
upstreamName := upstreamOIDCIdentityProvider.Name
|
||||
if upstreamNameOverride != "" {
|
||||
upstreamName = upstreamNameOverride
|
||||
}
|
||||
encoded, err := happyStateEncoder.Encode("s",
|
||||
expectedUpstreamStateParamFormat{
|
||||
oidctestutil.ExpectedUpstreamStateParamFormat{
|
||||
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
|
||||
U: upstreamName,
|
||||
N: happyNonce,
|
||||
C: csrf,
|
||||
K: happyPKCE,
|
||||
@ -234,7 +239,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
"nonce": happyNonce,
|
||||
"code_challenge": expectedUpstreamCodeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"redirect_uri": issuer + "/callback/some-idp",
|
||||
"redirect_uri": downstreamIssuer + "/callback",
|
||||
})
|
||||
}
|
||||
|
||||
@ -250,8 +255,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
generateCSRF func() (csrftoken.CSRFToken, error)
|
||||
generatePKCE func() (pkce.Code, error)
|
||||
generateNonce func() (nonce.Nonce, error)
|
||||
stateEncoder securecookie.Codec
|
||||
cookieEncoder securecookie.Codec
|
||||
stateEncoder oidc.Codec
|
||||
cookieEncoder oidc.Codec
|
||||
method string
|
||||
path string
|
||||
contentType string
|
||||
@ -271,8 +276,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "happy path using GET without a CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -283,14 +288,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
wantStatus: http.StatusFound,
|
||||
wantContentType: "text/html; charset=utf-8",
|
||||
wantCSRFValueInCookieHeader: happyCSRF,
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")),
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
wantBodyStringWithLocationInHref: true,
|
||||
},
|
||||
{
|
||||
name: "happy path using GET with a CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -298,17 +303,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
cookieEncoder: happyCookieEncoder,
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue,
|
||||
csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + " ",
|
||||
wantStatus: http.StatusFound,
|
||||
wantContentType: "text/html; charset=utf-8",
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)),
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
wantBodyStringWithLocationInHref: true,
|
||||
},
|
||||
{
|
||||
name: "happy path using POST",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -322,13 +327,33 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
wantContentType: "",
|
||||
wantBodyString: "",
|
||||
wantCSRFValueInCookieHeader: happyCSRF,
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")),
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
},
|
||||
{
|
||||
name: "error while decoding CSRF cookie just generates a new cookie and succeeds as usual",
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
stateEncoder: happyStateEncoder,
|
||||
cookieEncoder: happyCookieEncoder,
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
|
||||
wantStatus: http.StatusFound,
|
||||
wantContentType: "text/html; charset=utf-8",
|
||||
// Generated a new CSRF cookie and set it in the response.
|
||||
wantCSRFValueInCookieHeader: happyCSRF,
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
wantBodyStringWithLocationInHref: true,
|
||||
},
|
||||
{
|
||||
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -343,14 +368,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
wantCSRFValueInCookieHeader: happyCSRF,
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{
|
||||
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
|
||||
}, "")),
|
||||
}, "", "")),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
wantBodyStringWithLocationInHref: true,
|
||||
},
|
||||
{
|
||||
name: "downstream redirect uri does not match what is configured for client",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -366,8 +391,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "downstream client does not exist",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -381,8 +406,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "response type is unsupported",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -397,8 +422,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "downstream scopes do not match what is configured for client",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -413,8 +438,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "missing response type in request",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -429,8 +454,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "missing client id in request",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -444,8 +469,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -460,8 +485,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -476,8 +501,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -492,8 +517,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -510,8 +535,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
// This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running
|
||||
// through that part of the fosite library.
|
||||
name: "prompt param is not allowed to have none and another legal value at the same time",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -526,8 +551,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "OIDC validations are skipped when the openid scope was not requested",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -540,15 +565,15 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
wantContentType: "text/html; charset=utf-8",
|
||||
wantCSRFValueInCookieHeader: happyCSRF,
|
||||
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(
|
||||
map[string]string{"prompt": "none login", "scope": "email"}, "",
|
||||
map[string]string{"prompt": "none login", "scope": "email"}, "", "",
|
||||
)),
|
||||
wantUpstreamStateParamInLocationHeader: true,
|
||||
wantBodyStringWithLocationInHref: true,
|
||||
},
|
||||
{
|
||||
name: "state does not have enough entropy",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -563,8 +588,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error while encoding upstream state param",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -578,8 +603,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error while encoding CSRF cookie value for new cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -593,8 +618,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error while generating CSRF token",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -608,8 +633,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error while generating nonce",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
||||
@ -623,8 +648,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "error while generating PKCE",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -636,26 +661,10 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
wantContentType: "text/plain; charset=utf-8",
|
||||
wantBodyString: "Internal Server Error: error generating PKCE param\n",
|
||||
},
|
||||
{
|
||||
name: "error while decoding CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
stateEncoder: happyStateEncoder,
|
||||
cookieEncoder: happyCookieEncoder,
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantContentType: "text/plain; charset=utf-8",
|
||||
wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n",
|
||||
},
|
||||
{
|
||||
name: "no upstream providers are configured",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(), // empty
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(), // empty
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
@ -664,8 +673,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "too many upstream providers are configured",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
@ -674,8 +683,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "PUT is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
method: http.MethodPut,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -684,8 +693,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "PATCH is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
method: http.MethodPatch,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -694,8 +703,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "DELETE is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
issuer: downstreamIssuer,
|
||||
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||
method: http.MethodDelete,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -712,6 +721,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
}
|
||||
rsp := httptest.NewRecorder()
|
||||
subject.ServeHTTP(rsp, req)
|
||||
t.Logf("response: %#v", rsp)
|
||||
t.Logf("response body: %q", rsp.Body.String())
|
||||
|
||||
require.Equal(t, test.wantStatus, rsp.Code)
|
||||
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
|
||||
@ -742,7 +753,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
if test.wantCSRFValueInCookieHeader != "" {
|
||||
require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
|
||||
actualCookie := rsp.Header().Get("Set-Cookie")
|
||||
regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict")
|
||||
regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); Path=/; HttpOnly; Secure; SameSite=Strict")
|
||||
submatches := regex.FindStringSubmatch(actualCookie)
|
||||
require.Len(t, submatches, 2)
|
||||
captured := submatches[1]
|
||||
@ -772,13 +783,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
runOneTestCase(t, test, subject)
|
||||
|
||||
// Call the setter to change the upstream IDP settings.
|
||||
newProviderSettings := provider.UpstreamOIDCIdentityProvider{
|
||||
newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: "some-other-idp",
|
||||
ClientID: "some-other-client-id",
|
||||
AuthorizationURL: *upstreamAuthURL,
|
||||
Scopes: []string{"other-scope1", "other-scope2"},
|
||||
}
|
||||
test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{newProviderSettings})
|
||||
test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)})
|
||||
|
||||
// Update the expectations of the test case to match the new upstream IDP settings.
|
||||
test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(),
|
||||
@ -787,11 +798,11 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
"access_type": "offline",
|
||||
"scope": "other-scope1 other-scope2",
|
||||
"client_id": "some-other-client-id",
|
||||
"state": expectedUpstreamStateParam(nil, ""),
|
||||
"state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name),
|
||||
"nonce": happyNonce,
|
||||
"code_challenge": expectedUpstreamCodeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"redirect_uri": issuer + "/callback/some-other-idp",
|
||||
"redirect_uri": downstreamIssuer + "/callback",
|
||||
},
|
||||
)
|
||||
test.wantBodyString = fmt.Sprintf(`<a href="%s">Found</a>.%s`,
|
||||
@ -807,20 +818,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// Declare a separate type from the production code to ensure that the state param's contents was serialized
|
||||
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
|
||||
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
|
||||
// assertions about the redirect URL in this test.
|
||||
type expectedUpstreamStateParamFormat struct {
|
||||
P string `json:"p"`
|
||||
N string `json:"n"`
|
||||
C string `json:"c"`
|
||||
K string `json:"k"`
|
||||
V string `json:"v"`
|
||||
}
|
||||
|
||||
type errorReturningEncoder struct {
|
||||
securecookie.Codec
|
||||
oidc.Codec
|
||||
}
|
||||
|
||||
func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) {
|
||||
@ -843,7 +842,7 @@ func requireEqualContentType(t *testing.T, actual string, expected string) {
|
||||
require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
|
||||
}
|
||||
|
||||
func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) {
|
||||
func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) {
|
||||
t.Helper()
|
||||
actualLocationURL, err := url.Parse(actualURL)
|
||||
require.NoError(t, err)
|
||||
@ -852,13 +851,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
|
||||
|
||||
expectedQueryStateParam := expectedLocationURL.Query().Get("state")
|
||||
require.NotEmpty(t, expectedQueryStateParam)
|
||||
var expectedDecodedStateParam expectedUpstreamStateParamFormat
|
||||
var expectedDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
|
||||
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualQueryStateParam := actualLocationURL.Query().Get("state")
|
||||
require.NotEmpty(t, actualQueryStateParam)
|
||||
var actualDecodedStateParam expectedUpstreamStateParamFormat
|
||||
var actualDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
|
||||
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -871,10 +870,20 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
|
||||
require.NoError(t, err)
|
||||
expectedLocationURL, err := url.Parse(expectedURL)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme)
|
||||
require.Equal(t, expectedLocationURL.User, actualLocationURL.User)
|
||||
require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host)
|
||||
require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path)
|
||||
require.Equal(t, expectedLocationURL.Scheme, actualLocationURL.Scheme,
|
||||
"schemes were not equal: expected %s but got %s", expectedURL, actualURL,
|
||||
)
|
||||
require.Equal(t, expectedLocationURL.User, actualLocationURL.User,
|
||||
"users were not equal: expected %s but got %s", expectedURL, actualURL,
|
||||
)
|
||||
|
||||
require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host,
|
||||
"hosts were not equal: expected %s but got %s", expectedURL, actualURL,
|
||||
)
|
||||
|
||||
require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path,
|
||||
"paths were not equal: expected %s but got %s", expectedURL, actualURL,
|
||||
)
|
||||
|
||||
expectedLocationQuery := expectedLocationURL.Query()
|
||||
actualLocationQuery := actualLocationURL.Query()
|
||||
@ -886,9 +895,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
|
||||
}
|
||||
require.Equal(t, expectedLocationQuery, actualLocationQuery)
|
||||
}
|
||||
|
||||
func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
||||
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
|
||||
return idpProvider
|
||||
}
|
||||
|
306
internal/oidc/callback/callback_handler.go
Normal file
306
internal/oidc/callback/callback_handler.go
Normal file
@ -0,0 +1,306 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package callback provides a handler for the OIDC callback endpoint.
|
||||
package callback
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/ory/fosite/token/jwt"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
const (
|
||||
// The name of the issuer claim specified in the OIDC spec.
|
||||
idTokenIssuerClaim = "iss"
|
||||
|
||||
// The name of the subject claim specified in the OIDC spec.
|
||||
idTokenSubjectClaim = "sub"
|
||||
|
||||
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
|
||||
// ID token if the upstream OIDC IDP did not tell us to use another claim.
|
||||
defaultUpstreamUsernameClaim = idTokenSubjectClaim
|
||||
|
||||
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
||||
// information.
|
||||
downstreamGroupsClaim = "groups"
|
||||
)
|
||||
|
||||
func NewHandler(
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
stateDecoder, cookieDecoder oidc.Decoder,
|
||||
redirectURI string,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
state, err := validateRequest(r, stateDecoder, cookieDecoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
|
||||
if upstreamIDPConfig == nil {
|
||||
plog.Warning("upstream provider not found")
|
||||
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
|
||||
}
|
||||
|
||||
downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
|
||||
if err != nil {
|
||||
plog.Error("error reading state downstream auth params", err)
|
||||
return httperr.New(http.StatusBadRequest, "error reading state downstream auth params")
|
||||
}
|
||||
|
||||
// Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest().
|
||||
reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams}
|
||||
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest)
|
||||
if err != nil {
|
||||
plog.Error("error using state downstream auth params", err)
|
||||
return httperr.New(http.StatusBadRequest, "error using state downstream auth params")
|
||||
}
|
||||
|
||||
// Grant the openid scope only if it was requested.
|
||||
grantOpenIDScopeIfRequested(authorizeRequester)
|
||||
|
||||
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
|
||||
r.Context(),
|
||||
authcode(r),
|
||||
state.PKCECode,
|
||||
state.Nonce,
|
||||
redirectURI,
|
||||
)
|
||||
if err != nil {
|
||||
plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName())
|
||||
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
||||
}
|
||||
|
||||
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
openIDSession := makeDownstreamSession(username, groups)
|
||||
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
|
||||
if err != nil {
|
||||
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
|
||||
return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err)
|
||||
}
|
||||
|
||||
oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func authcode(r *http.Request) string {
|
||||
return r.FormValue("code")
|
||||
}
|
||||
|
||||
func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
|
||||
if r.Method != http.MethodGet {
|
||||
return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
|
||||
}
|
||||
|
||||
csrfValue, err := readCSRFCookie(r, cookieDecoder)
|
||||
if err != nil {
|
||||
plog.InfoErr("error reading CSRF cookie", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if authcode(r) == "" {
|
||||
plog.Info("code param not found")
|
||||
return nil, httperr.New(http.StatusBadRequest, "code param not found")
|
||||
}
|
||||
|
||||
if r.FormValue("state") == "" {
|
||||
plog.Info("state param not found")
|
||||
return nil, httperr.New(http.StatusBadRequest, "state param not found")
|
||||
}
|
||||
|
||||
state, err := readState(r, stateDecoder)
|
||||
if err != nil {
|
||||
plog.InfoErr("error reading state", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
|
||||
plog.InfoErr("CSRF value does not match", err)
|
||||
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
|
||||
for _, p := range idpListGetter.GetIDPList() {
|
||||
if p.GetName() == upstreamName {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) {
|
||||
receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
|
||||
if err != nil {
|
||||
// Error means that the cookie was not found
|
||||
return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err)
|
||||
}
|
||||
|
||||
var csrfFromCookie csrftoken.CSRFToken
|
||||
err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
|
||||
if err != nil {
|
||||
return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err)
|
||||
}
|
||||
|
||||
return csrfFromCookie, nil
|
||||
}
|
||||
|
||||
func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
|
||||
var state oidc.UpstreamStateParamData
|
||||
if err := stateDecoder.Decode(
|
||||
oidc.UpstreamStateParamEncodingName,
|
||||
r.FormValue("state"),
|
||||
&state,
|
||||
); err != nil {
|
||||
return nil, httperr.New(http.StatusBadRequest, "error reading state")
|
||||
}
|
||||
|
||||
if state.FormatVersion != oidc.UpstreamStateParamFormatVersion {
|
||||
return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid")
|
||||
}
|
||||
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
|
||||
for _, scope := range authorizeRequester.GetRequestedScopes() {
|
||||
if scope == "openid" {
|
||||
authorizeRequester.GrantScope(scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getUsernameFromUpstreamIDToken(
|
||||
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
|
||||
idTokenClaims map[string]interface{},
|
||||
) (string, error) {
|
||||
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
||||
|
||||
user := ""
|
||||
if usernameClaim == "" {
|
||||
// The spec says the "sub" claim is only unique per issuer, so by default when there is
|
||||
// no specific username claim configured we will prepend the issuer string to make it globally unique.
|
||||
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
||||
if upstreamIssuer == "" {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token missing",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
|
||||
}
|
||||
upstreamIssuerAsString, ok := upstreamIssuer.(string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"issuer claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"issClaim", upstreamIssuer,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
||||
}
|
||||
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
|
||||
usernameClaim = defaultUpstreamUsernameClaim
|
||||
}
|
||||
|
||||
usernameAsInterface, ok := idTokenClaims[usernameClaim]
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"no username claim in upstream ID token",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||
"usernameClaim", usernameClaim,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
|
||||
}
|
||||
|
||||
username, ok := usernameAsInterface.(string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"username claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
|
||||
"usernameClaim", usernameClaim,
|
||||
)
|
||||
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s%s", user, username), nil
|
||||
}
|
||||
|
||||
func getGroupsFromUpstreamIDToken(
|
||||
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
|
||||
idTokenClaims map[string]interface{},
|
||||
) ([]string, error) {
|
||||
groupsClaim := upstreamIDPConfig.GetGroupsClaim()
|
||||
if groupsClaim == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
groupsAsInterface, ok := idTokenClaims[groupsClaim]
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"no groups claim in upstream ID token",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(),
|
||||
"groupsClaim", groupsClaim,
|
||||
)
|
||||
return nil, httperr.New(http.StatusUnprocessableEntity, "no groups claim in upstream ID token")
|
||||
}
|
||||
|
||||
groups, ok := groupsAsInterface.([]string)
|
||||
if !ok {
|
||||
plog.Warning(
|
||||
"groups claim in upstream ID token has invalid format",
|
||||
"upstreamName", upstreamIDPConfig.GetName(),
|
||||
"configuredGroupsClaim", upstreamIDPConfig.GetGroupsClaim(),
|
||||
"groupsClaim", groupsClaim,
|
||||
)
|
||||
return nil, httperr.New(http.StatusUnprocessableEntity, "groups claim in upstream ID token has invalid format")
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func makeDownstreamSession(username string, groups []string) *openid.DefaultSession {
|
||||
now := time.Now().UTC()
|
||||
openIDSession := &openid.DefaultSession{
|
||||
Claims: &jwt.IDTokenClaims{
|
||||
Subject: username,
|
||||
RequestedAt: now,
|
||||
AuthTime: now,
|
||||
},
|
||||
}
|
||||
if groups != nil {
|
||||
openIDSession.Claims.Extra = map[string]interface{}{
|
||||
downstreamGroupsClaim: groups,
|
||||
}
|
||||
}
|
||||
return openIDSession
|
||||
}
|
859
internal/oidc/callback/callback_handler_test.go
Normal file
859
internal/oidc/callback/callback_handler_test.go
Normal file
@ -0,0 +1,859 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package callback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
"github.com/stretchr/testify/require"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
kubetesting "k8s.io/client-go/testing"
|
||||
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/oidctestutil"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
const (
|
||||
happyUpstreamIDPName = "upstream-idp-name"
|
||||
|
||||
upstreamIssuer = "https://my-upstream-issuer.com"
|
||||
upstreamSubject = "abc123-some-guid"
|
||||
upstreamUsername = "test-pinniped-username"
|
||||
|
||||
upstreamUsernameClaim = "the-user-claim"
|
||||
upstreamGroupsClaim = "the-groups-claim"
|
||||
|
||||
happyUpstreamAuthcode = "upstream-auth-code"
|
||||
|
||||
happyUpstreamRedirectURI = "https://example.com/callback"
|
||||
|
||||
happyDownstreamState = "some-downstream-state-with-at-least-32-bytes"
|
||||
happyDownstreamCSRF = "test-csrf"
|
||||
happyDownstreamPKCE = "test-pkce"
|
||||
happyDownstreamNonce = "test-nonce"
|
||||
happyDownstreamStateVersion = "1"
|
||||
|
||||
downstreamIssuer = "https://my-downstream-issuer.com/path"
|
||||
downstreamRedirectURI = "http://127.0.0.1/callback"
|
||||
downstreamClientID = "pinniped-cli"
|
||||
downstreamNonce = "some-nonce-value"
|
||||
downstreamPKCEChallenge = "some-challenge"
|
||||
downstreamPKCEChallengeMethod = "S256"
|
||||
|
||||
timeComparisonFudgeFactor = time.Second * 15
|
||||
)
|
||||
|
||||
var (
|
||||
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
|
||||
happyDownstreamScopesRequested = []string{"openid", "profile", "email"}
|
||||
|
||||
happyDownstreamRequestParamsQuery = url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{strings.Join(happyDownstreamScopesRequested, " ")},
|
||||
"client_id": []string{downstreamClientID},
|
||||
"state": []string{happyDownstreamState},
|
||||
"nonce": []string{downstreamNonce},
|
||||
"code_challenge": []string{downstreamPKCEChallenge},
|
||||
"code_challenge_method": []string{downstreamPKCEChallengeMethod},
|
||||
"redirect_uri": []string{downstreamRedirectURI},
|
||||
}
|
||||
happyDownstreamRequestParams = happyDownstreamRequestParamsQuery.Encode()
|
||||
)
|
||||
|
||||
func TestCallbackEndpoint(t *testing.T) {
|
||||
otherUpstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: "other-upstream-idp-name",
|
||||
ClientID: "other-some-client-id",
|
||||
Scopes: []string{"other-scope1", "other-scope2"},
|
||||
}
|
||||
|
||||
var stateEncoderHashKey = []byte("fake-hash-secret")
|
||||
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
|
||||
var cookieEncoderHashKey = []byte("fake-hash-secret2")
|
||||
var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
|
||||
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
|
||||
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
|
||||
|
||||
var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
|
||||
happyStateCodec.SetSerializer(securecookie.JSONEncoder{})
|
||||
var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
|
||||
happyCookieCodec.SetSerializer(securecookie.JSONEncoder{})
|
||||
|
||||
happyState := happyUpstreamStateParam().Build(t, happyStateCodec)
|
||||
|
||||
encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF)
|
||||
require.NoError(t, err)
|
||||
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
|
||||
|
||||
happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{
|
||||
Authcode: happyUpstreamAuthcode,
|
||||
PKCECodeVerifier: pkce.Code(happyDownstreamPKCE),
|
||||
ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce),
|
||||
RedirectURI: happyUpstreamRedirectURI,
|
||||
}
|
||||
|
||||
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
|
||||
happyDownstreamRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
idp oidctestutil.TestUpstreamOIDCIdentityProvider
|
||||
method string
|
||||
path string
|
||||
csrfCookie string
|
||||
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantRedirectLocationRegexp string
|
||||
wantGrantedOpenidScope bool
|
||||
wantDownstreamIDTokenSubject string
|
||||
wantDownstreamIDTokenGroups []string
|
||||
wantDownstreamRequestedScopes []string
|
||||
wantDownstreamNonce string
|
||||
wantDownstreamPKCEChallenge string
|
||||
wantDownstreamPKCEChallengeMethod string
|
||||
|
||||
wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs
|
||||
}{
|
||||
{
|
||||
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantGrantedOpenidScope: true,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamNonce: downstreamNonce,
|
||||
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
|
||||
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups",
|
||||
idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantGrantedOpenidScope: true,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
|
||||
wantDownstreamIDTokenGroups: nil,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamNonce: downstreamNonce,
|
||||
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
|
||||
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for",
|
||||
idp: happyUpstream().WithUsernameClaim("sub").Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp,
|
||||
wantGrantedOpenidScope: true,
|
||||
wantBody: "",
|
||||
wantDownstreamIDTokenSubject: upstreamSubject,
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
|
||||
wantDownstreamNonce: downstreamNonce,
|
||||
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
|
||||
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
|
||||
// Pre-upstream-exchange verification
|
||||
{
|
||||
name: "PUT method is invalid",
|
||||
method: http.MethodPut,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: PUT (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "POST method is invalid",
|
||||
method: http.MethodPost,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: POST (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "PATCH method is invalid",
|
||||
method: http.MethodPatch,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: PATCH (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "DELETE method is invalid",
|
||||
method: http.MethodDelete,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: DELETE (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "code param was not included on request",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).WithoutCode().String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: code param not found\n",
|
||||
},
|
||||
{
|
||||
name: "state param was not included on request",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithoutState().String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: state param not found\n",
|
||||
},
|
||||
{
|
||||
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState("this-will-not-decode").String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: error reading state\n",
|
||||
},
|
||||
{
|
||||
// This shouldn't happen in practice because the authorize endpoint should have already run the same
|
||||
// validations, but we would like to test the error handling in this endpoint anyway.
|
||||
name: "state param contains authorization request params which fail validation",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(
|
||||
happyUpstreamStateParam().
|
||||
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()).
|
||||
Build(t, happyStateCodec),
|
||||
).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
wantStatus: http.StatusInternalServerError,
|
||||
wantBody: "Internal Server Error: error while generating and saving authcode\n",
|
||||
},
|
||||
{
|
||||
name: "state's internal version does not match what we want",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: state format version is invalid\n",
|
||||
},
|
||||
{
|
||||
name: "state's downstream auth params element is invalid",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyUpstreamStateParam().
|
||||
WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z").
|
||||
Build(t, happyStateCodec)).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: error reading state downstream auth params\n",
|
||||
},
|
||||
{
|
||||
name: "state's downstream auth params are missing required value (e.g., client_id)",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(
|
||||
happyUpstreamStateParam().
|
||||
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"client_id": ""}).Encode()).
|
||||
Build(t, happyStateCodec),
|
||||
).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: error using state downstream auth params\n",
|
||||
},
|
||||
{
|
||||
name: "state's downstream auth params does not contain openid scope",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().
|
||||
WithState(
|
||||
happyUpstreamStateParam().
|
||||
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()).
|
||||
Build(t, happyStateCodec),
|
||||
).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
||||
wantDownstreamIDTokenSubject: upstreamUsername,
|
||||
wantDownstreamRequestedScopes: []string{"profile", "email"},
|
||||
wantDownstreamIDTokenGroups: upstreamGroupMembership,
|
||||
wantDownstreamNonce: downstreamNonce,
|
||||
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
|
||||
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "the UpstreamOIDCProvider CRD has been deleted",
|
||||
idp: otherUpstreamOIDCIdentityProvider,
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: upstream provider not found\n",
|
||||
},
|
||||
{
|
||||
name: "the CSRF cookie does not exist on request",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantBody: "Forbidden: CSRF cookie is missing\n",
|
||||
},
|
||||
{
|
||||
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantBody: "Forbidden: error reading CSRF cookie\n",
|
||||
},
|
||||
{
|
||||
name: "cookie csrf value does not match state csrf value",
|
||||
idp: happyUpstream().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusForbidden,
|
||||
wantBody: "Forbidden: CSRF value does not match\n",
|
||||
},
|
||||
|
||||
// Upstream exchange
|
||||
{
|
||||
name: "upstream auth code exchange fails",
|
||||
idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusBadGateway,
|
||||
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token does not contain requested username claim",
|
||||
idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: no username claim in upstream ID token\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token does not contain requested groups claim",
|
||||
idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token contains username claim with weird format",
|
||||
idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token does not contain iss claim when using default username claim config",
|
||||
idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token has an non-string iss claim when using default username claim config",
|
||||
idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
{
|
||||
name: "upstream ID token contains groups claim with weird format",
|
||||
idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(),
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState(happyState).String(),
|
||||
csrfCookie: happyCSRFCookie,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n",
|
||||
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
client := fake.NewSimpleClientset()
|
||||
secrets := client.CoreV1().Secrets("some-namespace")
|
||||
|
||||
// Configure fosite the same way that the production code would.
|
||||
// Inject this into our test subject at the last second so we get a fresh storage for every test.
|
||||
oauthStore := oidc.NewKubeStorage(secrets)
|
||||
hmacSecret := []byte("some secret - must have at least 32 bytes")
|
||||
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
||||
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret)
|
||||
|
||||
idpListGetter := oidctestutil.NewIDPListGetter(&test.idp)
|
||||
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI)
|
||||
req := httptest.NewRequest(test.method, test.path, nil)
|
||||
if test.csrfCookie != "" {
|
||||
req.Header.Set("Cookie", test.csrfCookie)
|
||||
}
|
||||
rsp := httptest.NewRecorder()
|
||||
subject.ServeHTTP(rsp, req)
|
||||
t.Logf("response: %#v", rsp)
|
||||
t.Logf("response body: %q", rsp.Body.String())
|
||||
|
||||
if test.wantExchangeAndValidateTokensCall != nil {
|
||||
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
|
||||
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
|
||||
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
|
||||
} else {
|
||||
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
|
||||
}
|
||||
|
||||
require.Equal(t, test.wantStatus, rsp.Code)
|
||||
|
||||
if test.wantBody != "" {
|
||||
require.Equal(t, test.wantBody, rsp.Body.String())
|
||||
} else {
|
||||
require.Empty(t, rsp.Body.String())
|
||||
}
|
||||
|
||||
if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test
|
||||
// Assert that Location header matches regular expression.
|
||||
require.Len(t, rsp.Header().Values("Location"), 1)
|
||||
actualLocation := rsp.Header().Get("Location")
|
||||
regex := regexp.MustCompile(test.wantRedirectLocationRegexp)
|
||||
submatches := regex.FindStringSubmatch(actualLocation)
|
||||
require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation)
|
||||
capturedAuthCode := submatches[1]
|
||||
|
||||
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface
|
||||
authcodeDataAndSignature := strings.Split(capturedAuthCode, ".")
|
||||
require.Len(t, authcodeDataAndSignature, 2)
|
||||
|
||||
// Several Secrets should have been created
|
||||
expectedNumberOfCreatedSecrets := 2
|
||||
if test.wantGrantedOpenidScope {
|
||||
expectedNumberOfCreatedSecrets++
|
||||
}
|
||||
require.Len(t, client.Actions(), expectedNumberOfCreatedSecrets)
|
||||
|
||||
actualSecretNames := []string{}
|
||||
for i := range client.Actions() {
|
||||
actualAction := client.Actions()[i].(kubetesting.CreateActionImpl)
|
||||
require.Equal(t, "create", actualAction.GetVerb())
|
||||
require.Equal(t, schema.GroupVersionResource{Group: "", Version: "v1", Resource: "secrets"}, actualAction.GetResource())
|
||||
actualSecret := actualAction.GetObject().(*corev1.Secret)
|
||||
require.Empty(t, actualSecret.Namespace) // because the secrets client is already scoped to a namespace
|
||||
actualSecretNames = append(actualSecretNames, actualSecret.Name)
|
||||
}
|
||||
|
||||
// One authcode should have been stored.
|
||||
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-authcode-")
|
||||
|
||||
storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage(
|
||||
t,
|
||||
oauthStore,
|
||||
authcodeDataAndSignature[1], // Authcode store key is authcode signature
|
||||
test.wantGrantedOpenidScope,
|
||||
test.wantDownstreamIDTokenSubject,
|
||||
test.wantDownstreamIDTokenGroups,
|
||||
test.wantDownstreamRequestedScopes,
|
||||
)
|
||||
|
||||
// One PKCE should have been stored.
|
||||
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-pkce-")
|
||||
|
||||
validatePKCEStorage(
|
||||
t,
|
||||
oauthStore,
|
||||
authcodeDataAndSignature[1], // PKCE store key is authcode signature
|
||||
storedRequestFromAuthcode,
|
||||
storedSessionFromAuthcode,
|
||||
test.wantDownstreamPKCEChallenge,
|
||||
test.wantDownstreamPKCEChallengeMethod,
|
||||
)
|
||||
|
||||
// One IDSession should have been stored, if the downstream actually requested the "openid" scope
|
||||
if test.wantGrantedOpenidScope {
|
||||
requireAnyStringHasPrefix(t, actualSecretNames, "pinniped-storage-oidc")
|
||||
|
||||
validateIDSessionStorage(
|
||||
t,
|
||||
oauthStore,
|
||||
capturedAuthCode, // IDSession store key is full authcode
|
||||
storedRequestFromAuthcode,
|
||||
storedSessionFromAuthcode,
|
||||
test.wantDownstreamNonce,
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type requestPath struct {
|
||||
code, state *string
|
||||
}
|
||||
|
||||
func newRequestPath() *requestPath {
|
||||
c := happyUpstreamAuthcode
|
||||
s := "4321"
|
||||
return &requestPath{
|
||||
code: &c,
|
||||
state: &s,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *requestPath) WithCode(code string) *requestPath {
|
||||
r.code = &code
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithoutCode() *requestPath {
|
||||
r.code = nil
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithState(state string) *requestPath {
|
||||
r.state = &state
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithoutState() *requestPath {
|
||||
r.state = nil
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) String() string {
|
||||
path := "/downstream-provider-name/callback?"
|
||||
params := url.Values{}
|
||||
if r.code != nil {
|
||||
params.Add("code", *r.code)
|
||||
}
|
||||
if r.state != nil {
|
||||
params.Add("state", *r.state)
|
||||
}
|
||||
return path + params.Encode()
|
||||
}
|
||||
|
||||
type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
|
||||
|
||||
func happyUpstreamStateParam() *upstreamStateParamBuilder {
|
||||
return &upstreamStateParamBuilder{
|
||||
U: happyUpstreamIDPName,
|
||||
P: happyDownstreamRequestParams,
|
||||
N: happyDownstreamNonce,
|
||||
C: happyDownstreamCSRF,
|
||||
K: happyDownstreamPKCE,
|
||||
V: happyDownstreamStateVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string {
|
||||
state, err := stateEncoder.Encode("s", b)
|
||||
require.NoError(t, err)
|
||||
return state
|
||||
}
|
||||
|
||||
func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder {
|
||||
b.P = params
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder {
|
||||
b.N = nonce
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder {
|
||||
b.C = csrf
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder {
|
||||
b.K = pkce
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder {
|
||||
b.V = version
|
||||
return b
|
||||
}
|
||||
|
||||
type upstreamOIDCIdentityProviderBuilder struct {
|
||||
idToken map[string]interface{}
|
||||
usernameClaim, groupsClaim string
|
||||
authcodeExchangeErr error
|
||||
}
|
||||
|
||||
func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
|
||||
return &upstreamOIDCIdentityProviderBuilder{
|
||||
usernameClaim: upstreamUsernameClaim,
|
||||
groupsClaim: upstreamGroupsClaim,
|
||||
idToken: map[string]interface{}{
|
||||
"iss": upstreamIssuer,
|
||||
"sub": upstreamSubject,
|
||||
upstreamUsernameClaim: upstreamUsername,
|
||||
upstreamGroupsClaim: upstreamGroupMembership,
|
||||
"other-claim": "should be ignored",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
|
||||
u.usernameClaim = claim
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder {
|
||||
u.usernameClaim = ""
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder {
|
||||
u.groupsClaim = ""
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *upstreamOIDCIdentityProviderBuilder {
|
||||
u.idToken[name] = value
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
|
||||
delete(u.idToken, claim)
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder {
|
||||
u.authcodeExchangeErr = err
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider {
|
||||
return oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: happyUpstreamIDPName,
|
||||
ClientID: "some-client-id",
|
||||
UsernameClaim: u.usernameClaim,
|
||||
GroupsClaim: u.groupsClaim,
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||
return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values {
|
||||
copied := url.Values{}
|
||||
for key, value := range query {
|
||||
copied[key] = value
|
||||
}
|
||||
for key, value := range modifications {
|
||||
if value == "" {
|
||||
copied.Del(key)
|
||||
} else {
|
||||
copied[key] = []string{value}
|
||||
}
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
func validateAuthcodeStorage(
|
||||
t *testing.T,
|
||||
oauthStore *oidc.KubeStorage,
|
||||
storeKey string,
|
||||
wantGrantedOpenidScope bool,
|
||||
wantDownstreamIDTokenSubject string,
|
||||
wantDownstreamIDTokenGroups []string,
|
||||
wantDownstreamRequestedScopes []string,
|
||||
) (*fosite.Request, *openid.DefaultSession) {
|
||||
t.Helper()
|
||||
|
||||
// Get the authcode session back from storage so we can require that it was stored correctly.
|
||||
storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that storage returned the expected concrete data types.
|
||||
storedRequestFromAuthcode, storedSessionFromAuthcode := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromAuthcode)
|
||||
|
||||
// Check which scopes were granted.
|
||||
if wantGrantedOpenidScope {
|
||||
require.Contains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid")
|
||||
} else {
|
||||
require.NotContains(t, storedRequestFromAuthcode.GetGrantedScopes(), "openid")
|
||||
}
|
||||
|
||||
// Check all the other fields of the stored request.
|
||||
require.NotEmpty(t, storedRequestFromAuthcode.ID)
|
||||
require.Equal(t, downstreamClientID, storedRequestFromAuthcode.Client.GetID())
|
||||
require.ElementsMatch(t, wantDownstreamRequestedScopes, storedRequestFromAuthcode.RequestedScope)
|
||||
require.Nil(t, storedRequestFromAuthcode.RequestedAudience)
|
||||
require.Empty(t, storedRequestFromAuthcode.GrantedAudience)
|
||||
require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequestFromAuthcode.Form)
|
||||
testutil.RequireTimeInDelta(t, time.Now(), storedRequestFromAuthcode.RequestedAt, timeComparisonFudgeFactor)
|
||||
|
||||
// We're not using these fields yet, so confirm that we did not set them (for now).
|
||||
require.Empty(t, storedSessionFromAuthcode.Subject)
|
||||
require.Empty(t, storedSessionFromAuthcode.Username)
|
||||
require.Empty(t, storedSessionFromAuthcode.Headers)
|
||||
|
||||
// The authcode that we are issuing should be good for the length of time that we declare in the fosite config.
|
||||
testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*3), storedSessionFromAuthcode.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor)
|
||||
require.Len(t, storedSessionFromAuthcode.ExpiresAt, 1)
|
||||
|
||||
// Now confirm the ID token claims.
|
||||
actualClaims := storedSessionFromAuthcode.Claims
|
||||
|
||||
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
|
||||
require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject)
|
||||
if wantDownstreamIDTokenGroups != nil {
|
||||
require.Len(t, actualClaims.Extra, 1)
|
||||
require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
|
||||
} else {
|
||||
require.Empty(t, actualClaims.Extra)
|
||||
require.NotContains(t, actualClaims.Extra, "groups")
|
||||
}
|
||||
|
||||
// Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time).
|
||||
testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor)
|
||||
testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.AuthTime, timeComparisonFudgeFactor)
|
||||
requestedAtZone, _ := actualClaims.RequestedAt.Zone()
|
||||
require.Equal(t, "UTC", requestedAtZone)
|
||||
authTimeZone, _ := actualClaims.AuthTime.Zone()
|
||||
require.Equal(t, "UTC", authTimeZone)
|
||||
|
||||
// Fosite will set these fields for us in the token endpoint based on the store session
|
||||
// information. Therefore, we assert that they are empty because we want the library to do the
|
||||
// lifting for us.
|
||||
require.Empty(t, actualClaims.Issuer)
|
||||
require.Nil(t, actualClaims.Audience)
|
||||
require.Empty(t, actualClaims.Nonce)
|
||||
require.Zero(t, actualClaims.ExpiresAt)
|
||||
require.Zero(t, actualClaims.IssuedAt)
|
||||
|
||||
// These are not needed yet.
|
||||
require.Empty(t, actualClaims.JTI)
|
||||
require.Empty(t, actualClaims.CodeHash)
|
||||
require.Empty(t, actualClaims.AccessTokenHash)
|
||||
require.Empty(t, actualClaims.AuthenticationContextClassReference)
|
||||
require.Empty(t, actualClaims.AuthenticationMethodsReference)
|
||||
|
||||
return storedRequestFromAuthcode, storedSessionFromAuthcode
|
||||
}
|
||||
|
||||
func validatePKCEStorage(
|
||||
t *testing.T,
|
||||
oauthStore *oidc.KubeStorage,
|
||||
storeKey string,
|
||||
storedRequestFromAuthcode *fosite.Request,
|
||||
storedSessionFromAuthcode *openid.DefaultSession,
|
||||
wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that storage returned the expected concrete data types.
|
||||
storedRequestFromPKCE, storedSessionFromPKCE := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromPKCE)
|
||||
|
||||
// The stored PKCE request should be the same as the stored authcode request.
|
||||
require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromPKCE.ID)
|
||||
require.Equal(t, storedSessionFromAuthcode, storedSessionFromPKCE)
|
||||
|
||||
// The stored PKCE request should also contain the PKCE challenge that the downstream sent us.
|
||||
require.Equal(t, wantDownstreamPKCEChallenge, storedRequestFromPKCE.Form.Get("code_challenge"))
|
||||
require.Equal(t, wantDownstreamPKCEChallengeMethod, storedRequestFromPKCE.Form.Get("code_challenge_method"))
|
||||
}
|
||||
|
||||
func validateIDSessionStorage(
|
||||
t *testing.T,
|
||||
oauthStore *oidc.KubeStorage,
|
||||
storeKey string,
|
||||
storedRequestFromAuthcode *fosite.Request,
|
||||
storedSessionFromAuthcode *openid.DefaultSession,
|
||||
wantDownstreamNonce string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that storage returned the expected concrete data types.
|
||||
storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession)
|
||||
|
||||
// The stored IDSession request should be the same as the stored authcode request.
|
||||
require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID)
|
||||
require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession)
|
||||
|
||||
// The stored IDSession request should also contain the nonce that the downstream sent us.
|
||||
require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce"))
|
||||
}
|
||||
|
||||
func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *openid.DefaultSession) {
|
||||
t.Helper()
|
||||
|
||||
storedRequest, ok := storedAuthorizeRequest.(*fosite.Request)
|
||||
require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest, &fosite.Request{})
|
||||
storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession)
|
||||
require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest.GetSession(), &openid.DefaultSession{})
|
||||
|
||||
return storedRequest, storedSession
|
||||
}
|
||||
|
||||
func requireAnyStringHasPrefix(t *testing.T, stringList []string, prefix string) {
|
||||
t.Helper()
|
||||
|
||||
containsPrefix := false
|
||||
for i := range stringList {
|
||||
if strings.HasPrefix(stringList[i], prefix) {
|
||||
containsPrefix = true
|
||||
}
|
||||
}
|
||||
require.Truef(t, containsPrefix, "list %v did not contain any strings with prefix %s", stringList, prefix)
|
||||
}
|
120
internal/oidc/kube_storage.go
Normal file
120
internal/oidc/kube_storage.go
Normal file
@ -0,0 +1,120 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/handler/oauth2"
|
||||
"github.com/ory/fosite/handler/openid"
|
||||
fositepkce "github.com/ory/fosite/handler/pkce"
|
||||
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
|
||||
"go.pinniped.dev/internal/constable"
|
||||
"go.pinniped.dev/internal/fositestorage/authorizationcode"
|
||||
"go.pinniped.dev/internal/fositestorage/openidconnect"
|
||||
"go.pinniped.dev/internal/fositestorage/pkce"
|
||||
)
|
||||
|
||||
const errKubeStorageNotImplemented = constable.Error("KubeStorage does not implement this method. It should not have been called.")
|
||||
|
||||
type KubeStorage struct {
|
||||
authorizationCodeStorage oauth2.AuthorizeCodeStorage
|
||||
pkceStorage fositepkce.PKCERequestStorage
|
||||
oidcStorage openid.OpenIDConnectRequestStorage
|
||||
}
|
||||
|
||||
func NewKubeStorage(secrets corev1client.SecretInterface) *KubeStorage {
|
||||
return &KubeStorage{
|
||||
authorizationCodeStorage: authorizationcode.New(secrets),
|
||||
pkceStorage: pkce.New(secrets),
|
||||
oidcStorage: openidconnect.New(secrets),
|
||||
}
|
||||
}
|
||||
|
||||
func (KubeStorage) RevokeRefreshToken(_ context.Context, _ string) error {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) RevokeAccessToken(_ context.Context, _ string) error {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (KubeStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
|
||||
return nil, errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (KubeStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
|
||||
return nil, errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (k KubeStorage) CreateOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) error {
|
||||
return k.oidcStorage.CreateOpenIDConnectSession(ctx, authcode, requester)
|
||||
}
|
||||
|
||||
func (k KubeStorage) GetOpenIDConnectSession(ctx context.Context, authcode string, requester fosite.Requester) (fosite.Requester, error) {
|
||||
return k.oidcStorage.GetOpenIDConnectSession(ctx, authcode, requester)
|
||||
}
|
||||
|
||||
func (k KubeStorage) DeleteOpenIDConnectSession(ctx context.Context, authcode string) error {
|
||||
return k.oidcStorage.DeleteOpenIDConnectSession(ctx, authcode)
|
||||
}
|
||||
|
||||
func (k KubeStorage) GetPKCERequestSession(ctx context.Context, signature string, session fosite.Session) (fosite.Requester, error) {
|
||||
return k.pkceStorage.GetPKCERequestSession(ctx, signature, session)
|
||||
}
|
||||
|
||||
func (k KubeStorage) CreatePKCERequestSession(ctx context.Context, signature string, requester fosite.Requester) error {
|
||||
return k.pkceStorage.CreatePKCERequestSession(ctx, signature, requester)
|
||||
}
|
||||
|
||||
func (k KubeStorage) DeletePKCERequestSession(ctx context.Context, signature string) error {
|
||||
return k.pkceStorage.DeletePKCERequestSession(ctx, signature)
|
||||
}
|
||||
|
||||
func (k KubeStorage) CreateAuthorizeCodeSession(ctx context.Context, signature string, r fosite.Requester) (err error) {
|
||||
return k.authorizationCodeStorage.CreateAuthorizeCodeSession(ctx, signature, r)
|
||||
}
|
||||
|
||||
func (k KubeStorage) GetAuthorizeCodeSession(ctx context.Context, signature string, s fosite.Session) (request fosite.Requester, err error) {
|
||||
return k.authorizationCodeStorage.GetAuthorizeCodeSession(ctx, signature, s)
|
||||
}
|
||||
|
||||
func (k KubeStorage) InvalidateAuthorizeCodeSession(ctx context.Context, signature string) (err error) {
|
||||
return k.authorizationCodeStorage.InvalidateAuthorizeCodeSession(ctx, signature)
|
||||
}
|
||||
|
||||
func (KubeStorage) GetClient(_ context.Context, id string) (fosite.Client, error) {
|
||||
client := PinnipedCLIOIDCClient()
|
||||
if client.ID == id {
|
||||
return client, nil
|
||||
}
|
||||
return nil, fosite.ErrNotFound
|
||||
}
|
||||
|
||||
func (KubeStorage) ClientAssertionJWTValid(_ context.Context, _ string) error {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
||||
|
||||
func (KubeStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error {
|
||||
return errKubeStorageNotImplemented
|
||||
}
|
@ -12,16 +12,16 @@ import (
|
||||
"go.pinniped.dev/internal/constable"
|
||||
)
|
||||
|
||||
const errNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.")
|
||||
const errNullStorageNotImplemented = constable.Error("NullStorage does not implement this method. It should not have been called.")
|
||||
|
||||
type NullStorage struct{}
|
||||
|
||||
func (NullStorage) RevokeRefreshToken(_ context.Context, _ string) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) RevokeAccessToken(_ context.Context, _ string) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
|
||||
@ -29,11 +29,11 @@ func (NullStorage) CreateRefreshTokenSession(_ context.Context, _ string, _ fosi
|
||||
}
|
||||
|
||||
func (NullStorage) GetRefreshTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
|
||||
return nil, errNotImplemented
|
||||
return nil, errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) DeleteRefreshTokenSession(_ context.Context, _ string) (err error) {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
|
||||
@ -41,11 +41,11 @@ func (NullStorage) CreateAccessTokenSession(_ context.Context, _ string, _ fosit
|
||||
}
|
||||
|
||||
func (NullStorage) GetAccessTokenSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
|
||||
return nil, errNotImplemented
|
||||
return nil, errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) DeleteAccessTokenSession(_ context.Context, _ string) (err error) {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) error {
|
||||
@ -53,15 +53,15 @@ func (NullStorage) CreateOpenIDConnectSession(_ context.Context, _ string, _ fos
|
||||
}
|
||||
|
||||
func (NullStorage) GetOpenIDConnectSession(_ context.Context, _ string, _ fosite.Requester) (fosite.Requester, error) {
|
||||
return nil, errNotImplemented
|
||||
return nil, errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) DeleteOpenIDConnectSession(_ context.Context, _ string) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) GetPKCERequestSession(_ context.Context, _ string, _ fosite.Session) (fosite.Requester, error) {
|
||||
return nil, errNotImplemented
|
||||
return nil, errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosite.Requester) error {
|
||||
@ -69,7 +69,7 @@ func (NullStorage) CreatePKCERequestSession(_ context.Context, _ string, _ fosit
|
||||
}
|
||||
|
||||
func (NullStorage) DeletePKCERequestSession(_ context.Context, _ string) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Requester) (err error) {
|
||||
@ -77,11 +77,11 @@ func (NullStorage) CreateAuthorizeCodeSession(_ context.Context, _ string, _ fos
|
||||
}
|
||||
|
||||
func (NullStorage) GetAuthorizeCodeSession(_ context.Context, _ string, _ fosite.Session) (request fosite.Requester, err error) {
|
||||
return nil, errNotImplemented
|
||||
return nil, errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) InvalidateAuthorizeCodeSession(_ context.Context, _ string) (err error) {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error) {
|
||||
@ -93,9 +93,9 @@ func (NullStorage) GetClient(_ context.Context, id string) (fosite.Client, error
|
||||
}
|
||||
|
||||
func (NullStorage) ClientAssertionJWTValid(_ context.Context, _ string) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
||||
func (NullStorage) SetClientAssertionJWT(_ context.Context, _ string, _ time.Time) error {
|
||||
return errNotImplemented
|
||||
return errNullStorageNotImplemented
|
||||
}
|
||||
|
@ -10,15 +10,72 @@ import (
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/compose"
|
||||
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
const (
|
||||
WellKnownEndpointPath = "/.well-known/openid-configuration"
|
||||
AuthorizationEndpointPath = "/oauth2/authorize"
|
||||
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
|
||||
CallbackEndpointPath = "/callback"
|
||||
JWKSEndpointPath = "/jwks.json"
|
||||
)
|
||||
|
||||
const (
|
||||
// Just in case we need to make a breaking change to the format of the upstream state param,
|
||||
// we are including a format version number. This gives the opportunity for a future version of Pinniped
|
||||
// to have the consumer of this format decide to reject versions that it doesn't understand.
|
||||
UpstreamStateParamFormatVersion = "1"
|
||||
|
||||
// The `name` passed to the encoder for encoding the upstream state param value. This name is short
|
||||
// because it will be encoded into the upstream state param value and we're trying to keep that small.
|
||||
UpstreamStateParamEncodingName = "s"
|
||||
|
||||
// CSRFCookieName is the name of the browser cookie which shall hold our CSRF value.
|
||||
// The `__Host` prefix has a special meaning. See:
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes.
|
||||
CSRFCookieName = "__Host-pinniped-csrf"
|
||||
|
||||
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
|
||||
// cookie contents.
|
||||
CSRFCookieEncodingName = "csrf"
|
||||
)
|
||||
|
||||
// Encoder is the encoding side of the securecookie.Codec interface.
|
||||
type Encoder interface {
|
||||
Encode(name string, value interface{}) (string, error)
|
||||
}
|
||||
|
||||
// Decoder is the decoding side of the securecookie.Codec interface.
|
||||
type Decoder interface {
|
||||
Decode(name, value string, into interface{}) error
|
||||
}
|
||||
|
||||
// Codec is both the encoding and decoding sides of the securecookie.Codec interface. It is
|
||||
// interface'd here so that we properly wrap the securecookie dependency.
|
||||
type Codec interface {
|
||||
Encoder
|
||||
Decoder
|
||||
}
|
||||
|
||||
// UpstreamStateParamData is the format of the state parameter that we use when we communicate to an
|
||||
// upstream OIDC provider.
|
||||
//
|
||||
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on
|
||||
// the state param.
|
||||
type UpstreamStateParamData struct {
|
||||
AuthParams string `json:"p"`
|
||||
UpstreamName string `json:"u"`
|
||||
Nonce nonce.Nonce `json:"n"`
|
||||
CSRFToken csrftoken.CSRFToken `json:"c"`
|
||||
PKCECode pkce.Code `json:"k"`
|
||||
FormatVersion string `json:"v"`
|
||||
}
|
||||
|
||||
func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient {
|
||||
return &fosite.DefaultOpenIDConnectClient{
|
||||
DefaultClient: &fosite.DefaultClient{
|
||||
@ -34,8 +91,8 @@ func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient {
|
||||
}
|
||||
|
||||
func FositeOauth2Helper(
|
||||
issuerURL string,
|
||||
oauthStore fosite.Storage,
|
||||
oauthStore interface{},
|
||||
issuer string,
|
||||
hmacSecretOfLengthAtLeast32 []byte,
|
||||
jwtSigningKey *ecdsa.PrivateKey,
|
||||
) fosite.OAuth2Provider {
|
||||
@ -47,7 +104,7 @@ func FositeOauth2Helper(
|
||||
|
||||
RefreshTokenLifespan: 16 * time.Hour, // long enough for a single workday
|
||||
|
||||
IDTokenIssuer: issuerURL,
|
||||
IDTokenIssuer: issuer,
|
||||
TokenURL: "", // TODO set once we have this endpoint written
|
||||
|
||||
ScopeStrategy: fosite.ExactScopeStrategy, // be careful and only support exact string matching for scopes
|
||||
@ -75,3 +132,7 @@ func FositeOauth2Helper(
|
||||
compose.OAuth2PKCEFactory,
|
||||
)
|
||||
}
|
||||
|
||||
type IDPListGetter interface {
|
||||
GetIDPList() []provider.UpstreamOIDCIdentityProviderI
|
||||
}
|
||||
|
129
internal/oidc/oidctestutil/oidc.go
Normal file
129
internal/oidc/oidctestutil/oidc.go
Normal file
@ -0,0 +1,129 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oidctestutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
// Test helpers for the OIDC package.
|
||||
|
||||
// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to
|
||||
// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc().
|
||||
type ExchangeAuthcodeAndValidateTokenArgs struct {
|
||||
Ctx context.Context
|
||||
Authcode string
|
||||
PKCECodeVerifier pkce.Code
|
||||
ExpectedIDTokenNonce nonce.Nonce
|
||||
RedirectURI string
|
||||
}
|
||||
|
||||
type TestUpstreamOIDCIdentityProvider struct {
|
||||
Name string
|
||||
ClientID string
|
||||
AuthorizationURL url.URL
|
||||
UsernameClaim string
|
||||
GroupsClaim string
|
||||
Scopes []string
|
||||
ExchangeAuthcodeAndValidateTokensFunc func(
|
||||
ctx context.Context,
|
||||
authcode string,
|
||||
pkceCodeVerifier pkce.Code,
|
||||
expectedIDTokenNonce nonce.Nonce,
|
||||
) (oidctypes.Token, map[string]interface{}, error)
|
||||
|
||||
exchangeAuthcodeAndValidateTokensCallCount int
|
||||
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
|
||||
return u.Name
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string {
|
||||
return u.ClientID
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL {
|
||||
return &u.AuthorizationURL
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string {
|
||||
return u.Scopes
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string {
|
||||
return u.UsernameClaim
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string {
|
||||
return u.GroupsClaim
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
|
||||
ctx context.Context,
|
||||
authcode string,
|
||||
pkceCodeVerifier pkce.Code,
|
||||
expectedIDTokenNonce nonce.Nonce,
|
||||
redirectURI string,
|
||||
) (oidctypes.Token, map[string]interface{}, error) {
|
||||
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
|
||||
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
|
||||
}
|
||||
u.exchangeAuthcodeAndValidateTokensCallCount++
|
||||
u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{
|
||||
Ctx: ctx,
|
||||
Authcode: authcode,
|
||||
PKCECodeVerifier: pkceCodeVerifier,
|
||||
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||
RedirectURI: redirectURI,
|
||||
})
|
||||
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int {
|
||||
return u.exchangeAuthcodeAndValidateTokensCallCount
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs {
|
||||
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
|
||||
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
|
||||
}
|
||||
return u.exchangeAuthcodeAndValidateTokensArgs[call]
|
||||
}
|
||||
|
||||
func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
||||
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||
upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
|
||||
for i := range upstreamOIDCIdentityProviders {
|
||||
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i])
|
||||
}
|
||||
idpProvider.SetIDPList(upstreams)
|
||||
return idpProvider
|
||||
}
|
||||
|
||||
// Declare a separate type from the production code to ensure that the state param's contents was serialized
|
||||
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
|
||||
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
|
||||
// assertions about the redirect URL in this test.
|
||||
type ExpectedUpstreamStateParamFormat struct {
|
||||
P string `json:"p"`
|
||||
U string `json:"u"`
|
||||
N string `json:"n"`
|
||||
C string `json:"c"`
|
||||
K string `json:"k"`
|
||||
V string `json:"v"`
|
||||
}
|
@ -4,48 +4,73 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
type UpstreamOIDCIdentityProvider struct {
|
||||
type UpstreamOIDCIdentityProviderI interface {
|
||||
// A name for this upstream provider, which will be used as a component of the path for the callback endpoint
|
||||
// hosted by the Supervisor.
|
||||
Name string
|
||||
GetName() string
|
||||
|
||||
// The Oauth client ID registered with the upstream provider to be used in the authorization flow.
|
||||
ClientID string
|
||||
// The Oauth client ID registered with the upstream provider to be used in the authorization code flow.
|
||||
GetClientID() string
|
||||
|
||||
// The Authorization Endpoint fetched from discovery.
|
||||
AuthorizationURL url.URL
|
||||
GetAuthorizationURL() *url.URL
|
||||
|
||||
// Scopes to request in authorization flow.
|
||||
Scopes []string
|
||||
GetScopes() []string
|
||||
|
||||
// ID Token username claim name. May return empty string, in which case we will use some reasonable defaults.
|
||||
GetUsernameClaim() string
|
||||
|
||||
// ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider.
|
||||
GetGroupsClaim() string
|
||||
|
||||
// Performs upstream OIDC authorization code exchange and token validation.
|
||||
// Returns the validated raw tokens as well as the parsed claims of the ID token.
|
||||
ExchangeAuthcodeAndValidateTokens(
|
||||
ctx context.Context,
|
||||
authcode string,
|
||||
pkceCodeVerifier pkce.Code,
|
||||
expectedIDTokenNonce nonce.Nonce,
|
||||
redirectURI string,
|
||||
) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error)
|
||||
|
||||
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error)
|
||||
}
|
||||
|
||||
type DynamicUpstreamIDPProvider interface {
|
||||
SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider)
|
||||
GetIDPList() []UpstreamOIDCIdentityProvider
|
||||
SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI)
|
||||
GetIDPList() []UpstreamOIDCIdentityProviderI
|
||||
}
|
||||
|
||||
type dynamicUpstreamIDPProvider struct {
|
||||
oidcProviders []UpstreamOIDCIdentityProvider
|
||||
oidcProviders []UpstreamOIDCIdentityProviderI
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider {
|
||||
return &dynamicUpstreamIDPProvider{
|
||||
oidcProviders: []UpstreamOIDCIdentityProvider{},
|
||||
oidcProviders: []UpstreamOIDCIdentityProviderI{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) {
|
||||
func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) {
|
||||
p.mutex.Lock() // acquire a write lock
|
||||
defer p.mutex.Unlock()
|
||||
p.oidcProviders = oidcIDPs
|
||||
}
|
||||
|
||||
func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProvider {
|
||||
func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProviderI {
|
||||
p.mutex.RLock() // acquire a read lock
|
||||
defer p.mutex.RUnlock()
|
||||
return p.oidcProviders
|
||||
|
@ -9,9 +9,11 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
|
||||
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/auth"
|
||||
"go.pinniped.dev/internal/oidc/callback"
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/discovery"
|
||||
"go.pinniped.dev/internal/oidc/jwks"
|
||||
@ -30,19 +32,26 @@ type Manager struct {
|
||||
providerHandlers map[string]http.Handler // map of all routes for all providers
|
||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
|
||||
idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs
|
||||
idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
|
||||
secretsClient corev1client.SecretInterface
|
||||
}
|
||||
|
||||
// NewManager returns an empty Manager.
|
||||
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
||||
// dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data.
|
||||
// idpListGetter will be used as an in-memory cache of currently configured upstream IDPs.
|
||||
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager {
|
||||
func NewManager(
|
||||
nextHandler http.Handler,
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider,
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
secretsClient corev1client.SecretInterface,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
providerHandlers: make(map[string]http.Handler),
|
||||
nextHandler: nextHandler,
|
||||
dynamicJWKSProvider: dynamicJWKSProvider,
|
||||
idpListGetter: idpListGetter,
|
||||
secretsClient: secretsClient,
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,20 +71,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||
m.providerHandlers = make(map[string]http.Handler)
|
||||
|
||||
for _, incomingProvider := range oidcProviders {
|
||||
wellKnownURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.WellKnownEndpointPath
|
||||
m.providerHandlers[wellKnownURL] = discovery.NewHandler(incomingProvider.Issuer())
|
||||
issuer := incomingProvider.Issuer()
|
||||
issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath()
|
||||
|
||||
jwksURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.JWKSEndpointPath
|
||||
m.providerHandlers[jwksURL] = jwks.NewHandler(incomingProvider.Issuer(), m.dynamicJWKSProvider)
|
||||
fositeHMACSecretForThisProvider := []byte("some secret - must have at least 32 bytes") // TODO replace this secret
|
||||
|
||||
// Use NullStorage for the authorize endpoint because we do not actually want to store anything until
|
||||
// the upstream callback endpoint is called later.
|
||||
oauthHelper := oidc.FositeOauth2Helper(
|
||||
incomingProvider.Issuer(),
|
||||
oidc.NullStorage{},
|
||||
[]byte("some secret - must have at least 32 bytes"), // TODO replace this secret
|
||||
nil, // TODO: inject me properly
|
||||
)
|
||||
oauthHelperWithNullStorage := oidc.FositeOauth2Helper(oidc.NullStorage{}, issuer, fositeHMACSecretForThisProvider, nil)
|
||||
|
||||
// For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
|
||||
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(oidc.NewKubeStorage(m.secretsClient), issuer, fositeHMACSecretForThisProvider, nil)
|
||||
|
||||
// TODO use different codecs for the state and the cookie, because:
|
||||
// 1. we would like to state to have an embedded expiration date while the cookie does not need that
|
||||
@ -86,10 +92,30 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
||||
var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
|
||||
encoder.SetSerializer(securecookie.JSONEncoder{})
|
||||
|
||||
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
|
||||
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder)
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer)
|
||||
|
||||
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider)
|
||||
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler(
|
||||
issuer,
|
||||
m.idpListGetter,
|
||||
oauthHelperWithNullStorage,
|
||||
csrftoken.Generate,
|
||||
pkce.Generate,
|
||||
nonce.Generate,
|
||||
encoder,
|
||||
encoder,
|
||||
)
|
||||
|
||||
m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler(
|
||||
m.idpListGetter,
|
||||
oauthHelperWithKubeStorage,
|
||||
encoder,
|
||||
encoder,
|
||||
issuer+oidc.CallbackEndpointPath,
|
||||
)
|
||||
|
||||
plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@ -15,12 +16,17 @@ import (
|
||||
"github.com/sclevine/spec"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"k8s.io/client-go/kubernetes/fake"
|
||||
|
||||
"go.pinniped.dev/internal/here"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/discovery"
|
||||
"go.pinniped.dev/internal/oidc/jwks"
|
||||
"go.pinniped.dev/internal/oidc/oidctestutil"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
func TestManager(t *testing.T) {
|
||||
@ -31,6 +37,7 @@ func TestManager(t *testing.T) {
|
||||
nextHandler http.HandlerFunc
|
||||
fallbackHandlerWasCalled bool
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider
|
||||
kubeClient *fake.Clientset
|
||||
)
|
||||
|
||||
const (
|
||||
@ -41,6 +48,7 @@ func TestManager(t *testing.T) {
|
||||
issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path"
|
||||
issuer2KeyID = "issuer2-key"
|
||||
upstreamIDPAuthorizationURL = "https://test-upstream.com/auth"
|
||||
downstreamRedirectURL = "http://127.0.0.1:12345/callback"
|
||||
)
|
||||
|
||||
newGetRequest := func(url string) *http.Request {
|
||||
@ -64,7 +72,7 @@ func TestManager(t *testing.T) {
|
||||
r.Equal(expectedIssuerInResponse, parsedDiscoveryResult.Issuer)
|
||||
}
|
||||
|
||||
requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) {
|
||||
requireAuthorizationRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedRedirectLocationPrefix string) (string, string) {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.AuthorizationEndpointPath+requestURLSuffix))
|
||||
@ -79,6 +87,58 @@ func TestManager(t *testing.T) {
|
||||
"actual location %s did not start with expected prefix %s",
|
||||
actualLocation, expectedRedirectLocationPrefix,
|
||||
)
|
||||
|
||||
parsedLocation, err := url.Parse(actualLocation)
|
||||
r.NoError(err)
|
||||
redirectStateParam := parsedLocation.Query().Get("state")
|
||||
r.NotEmpty(redirectStateParam)
|
||||
|
||||
cookieValueAndDirectivesSplit := strings.SplitN(recorder.Header().Get("Set-Cookie"), ";", 2)
|
||||
r.Len(cookieValueAndDirectivesSplit, 2)
|
||||
cookieKeyValueSplit := strings.Split(cookieValueAndDirectivesSplit[0], "=")
|
||||
r.Len(cookieKeyValueSplit, 2)
|
||||
csrfCookieName := cookieKeyValueSplit[0]
|
||||
r.Equal("__Host-pinniped-csrf", csrfCookieName)
|
||||
csrfCookieValue := cookieKeyValueSplit[1]
|
||||
r.NotEmpty(csrfCookieValue)
|
||||
|
||||
// Return the important parts of the response so we can use them in our next request to the callback endpoint
|
||||
return csrfCookieValue, redirectStateParam
|
||||
}
|
||||
|
||||
requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix, csrfCookieValue string) {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
numberOfKubeActionsBeforeThisRequest := len(kubeClient.Actions())
|
||||
|
||||
getRequest := newGetRequest(requestIssuer + oidc.CallbackEndpointPath + requestURLSuffix)
|
||||
getRequest.AddCookie(&http.Cookie{
|
||||
Name: "__Host-pinniped-csrf",
|
||||
Value: csrfCookieValue,
|
||||
})
|
||||
subject.ServeHTTP(recorder, getRequest)
|
||||
|
||||
r.False(fallbackHandlerWasCalled)
|
||||
|
||||
// Check just enough of the response to ensure that we wired up the callback endpoint correctly.
|
||||
// The endpoint's own unit tests cover everything else.
|
||||
r.Equal(http.StatusFound, recorder.Code)
|
||||
actualLocation := recorder.Header().Get("Location")
|
||||
r.True(
|
||||
strings.HasPrefix(actualLocation, downstreamRedirectURL),
|
||||
"actual location %s did not start with expected prefix %s",
|
||||
actualLocation, downstreamRedirectURL,
|
||||
)
|
||||
parsedLocation, err := url.Parse(actualLocation)
|
||||
r.NoError(err)
|
||||
actualLocationQueryParams := parsedLocation.Query()
|
||||
r.Contains(actualLocationQueryParams, "code")
|
||||
r.Equal("openid", actualLocationQueryParams.Get("scope"))
|
||||
r.Equal("some-state-value-that-is-32-byte", actualLocationQueryParams.Get("state"))
|
||||
|
||||
// Make sure that we wired up the callback endpoint to use kube storage for fosite sessions.
|
||||
r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3,
|
||||
"did not perform any kube actions during the callback request, but should have")
|
||||
}
|
||||
|
||||
requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) {
|
||||
@ -107,17 +167,27 @@ func TestManager(t *testing.T) {
|
||||
|
||||
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
|
||||
r.NoError(err)
|
||||
idpListGetter := provider.NewDynamicUpstreamIDPProvider()
|
||||
idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{
|
||||
{
|
||||
idpListGetter := oidctestutil.NewIDPListGetter(&oidctestutil.TestUpstreamOIDCIdentityProvider{
|
||||
Name: "test-idp",
|
||||
ClientID: "test-client-id",
|
||||
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
|
||||
Scopes: []string{"test-scope"},
|
||||
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||
return oidctypes.Token{},
|
||||
map[string]interface{}{
|
||||
"iss": "https://some-issuer.com",
|
||||
"sub": "some-subject",
|
||||
"username": "test-username",
|
||||
"groups": "test-group1",
|
||||
},
|
||||
nil
|
||||
},
|
||||
})
|
||||
|
||||
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter)
|
||||
kubeClient = fake.NewSimpleClientset()
|
||||
secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
|
||||
|
||||
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter, secretsClient)
|
||||
})
|
||||
|
||||
when("given no providers via SetProviders()", func() {
|
||||
@ -164,7 +234,6 @@ func TestManager(t *testing.T) {
|
||||
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID)
|
||||
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID)
|
||||
|
||||
authRedirectURI := "http://127.0.0.1/callback"
|
||||
authRequestParams := "?" + url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"openid profile email"},
|
||||
@ -173,7 +242,7 @@ func TestManager(t *testing.T) {
|
||||
"nonce": []string{"some-nonce-value"},
|
||||
"code_challenge": []string{"some-challenge"},
|
||||
"code_challenge_method": []string{"S256"},
|
||||
"redirect_uri": []string{authRedirectURI},
|
||||
"redirect_uri": []string{downstreamRedirectURL},
|
||||
}.Encode()
|
||||
|
||||
requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL)
|
||||
@ -181,7 +250,20 @@ func TestManager(t *testing.T) {
|
||||
|
||||
// Hostnames are case-insensitive, so test that we can handle that.
|
||||
requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
|
||||
csrfCookieValue, upstreamStateParam :=
|
||||
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
|
||||
|
||||
callbackRequestParams := "?" + url.Values{
|
||||
"code": []string{"some-fake-code"},
|
||||
"state": []string{upstreamStateParam},
|
||||
}.Encode()
|
||||
|
||||
requireCallbackRequestToBeHandled(issuer1, callbackRequestParams, csrfCookieValue)
|
||||
requireCallbackRequestToBeHandled(issuer2, callbackRequestParams, csrfCookieValue)
|
||||
|
||||
// // Hostnames are case-insensitive, so test that we can handle that.
|
||||
requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
|
||||
requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams, csrfCookieValue)
|
||||
}
|
||||
|
||||
when("given some valid providers via SetProviders()", func() {
|
||||
|
24
internal/testutil/assertions.go
Normal file
24
internal/testutil/assertions.go
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
|
||||
require.InDeltaf(t,
|
||||
float64(t1.UnixNano()),
|
||||
float64(t2.UnixNano()),
|
||||
float64(delta.Nanoseconds()),
|
||||
"expected %s and %s to be < %s apart, but they are %s apart",
|
||||
t1.Format(time.RFC3339Nano),
|
||||
t2.Format(time.RFC3339Nano),
|
||||
delta.String(),
|
||||
t1.Sub(t2).String(),
|
||||
)
|
||||
}
|
117
internal/upstreamoidc/upstreamoidc.go
Normal file
117
internal/upstreamoidc/upstreamoidc.go
Normal file
@ -0,0 +1,117 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package upstreamoidc implements an abstraction of upstream OIDC provider interactions.
|
||||
package upstreamoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
func New(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
return &ProviderConfig{Config: config, Provider: provider, Client: client}
|
||||
}
|
||||
|
||||
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
UsernameClaim string
|
||||
GroupsClaim string
|
||||
Config *oauth2.Config
|
||||
Provider interface {
|
||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||
}
|
||||
Client *http.Client
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetClientID() string {
|
||||
return p.Config.ClientID
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetAuthorizationURL() *url.URL {
|
||||
result, _ := url.Parse(p.Config.Endpoint.AuthURL)
|
||||
return result
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetScopes() []string {
|
||||
return p.Config.Scopes
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetUsernameClaim() string {
|
||||
return p.UsernameClaim
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) GetGroupsClaim() string {
|
||||
return p.GroupsClaim
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) {
|
||||
tok, err := p.Config.Exchange(
|
||||
oidc.ClientContext(ctx, p.Client),
|
||||
authcode,
|
||||
pkceCodeVerifier.Verifier(),
|
||||
oauth2.SetAuthURLParam("redirect_uri", redirectURI),
|
||||
)
|
||||
if err != nil {
|
||||
return oidctypes.Token{}, nil, err
|
||||
}
|
||||
|
||||
return p.ValidateToken(ctx, tok, expectedIDTokenNonce)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok)
|
||||
if err != nil {
|
||||
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if expectedIDTokenNonce != "" {
|
||||
if err := expectedIDTokenNonce.Validate(validated); err != nil {
|
||||
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
|
||||
var validatedClaims map[string]interface{}
|
||||
if err := validated.Claims(&validatedClaims); err != nil {
|
||||
return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err)
|
||||
}
|
||||
|
||||
return oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: tok.AccessToken,
|
||||
Type: tok.TokenType,
|
||||
Expiry: metav1.NewTime(tok.Expiry),
|
||||
},
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: tok.RefreshToken,
|
||||
},
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: idTok,
|
||||
Expiry: metav1.NewTime(validated.Expiry),
|
||||
},
|
||||
}, validatedClaims, nil
|
||||
}
|
223
internal/upstreamoidc/upstreamoidc_test.go
Normal file
223
internal/upstreamoidc/upstreamoidc_test.go
Normal file
@ -0,0 +1,223 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package upstreamoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
func TestProviderConfig(t *testing.T) {
|
||||
t.Run("getters get", func(t *testing.T) {
|
||||
p := ProviderConfig{
|
||||
Name: "test-name",
|
||||
UsernameClaim: "test-username-claim",
|
||||
GroupsClaim: "test-groups-claim",
|
||||
Config: &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
Endpoint: oauth2.Endpoint{AuthURL: "https://example.com"},
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
},
|
||||
}
|
||||
require.Equal(t, "test-name", p.GetName())
|
||||
require.Equal(t, "test-client-id", p.GetClientID())
|
||||
require.Equal(t, "https://example.com", p.GetAuthorizationURL().String())
|
||||
require.ElementsMatch(t, []string{"scope1", "scope2"}, p.GetScopes())
|
||||
require.Equal(t, "test-username-claim", p.GetUsernameClaim())
|
||||
require.Equal(t, "test-groups-claim", p.GetGroupsClaim())
|
||||
})
|
||||
|
||||
const (
|
||||
// Test JWTs generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
invalidAccessTokenHashIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg" //nolint: gosec
|
||||
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
invalidNonceIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ" //nolint: gosec
|
||||
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"foo": "bar", "bat": "baz"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
validIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImJhdCI6ImJheiIsImZvbyI6ImJhciIsImlhdCI6MTYwNjc2ODU5MywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDY3Njg1OTMsInN1YiI6InRlc3QtdXNlciJ9.DuqVZ7pGhHqKz7gNr4j2W1s1N8YrSltktH4wW19L4oD1OE2-O72jAnNj5xdjilsa8l7h9ox-5sMF0Tkh3BdRlHQK9dEtNm9tW-JreUnWJ3LCqUs-LZp4NG7edvq2sH_1Bn7O2_NQV51s8Pl04F60CndjQ4NM-6WkqDQTKyY6vJXU7idvM-6TM2HJZK-Na88cOJ9KIK37tL5DhcbsHVF47Dq8uPZ0KbjNQjJLAIi_1GeQBgc6yJhDUwRY4Xu6S0dtTHA6xTI8oSXoamt4bkViEHfJBp97LZQiNz8mku5pVc0aNwP1p4hMHxRHhLXrJjbh-Hx4YFjxtOnIq9t1mHlD4A" //nolint: gosec
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authCode string
|
||||
expectNonce nonce.Nonce
|
||||
returnIDTok string
|
||||
wantErr string
|
||||
wantToken oidctypes.Token
|
||||
wantClaims map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "exchange fails with network error",
|
||||
authCode: "invalid-auth-code",
|
||||
wantErr: "oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
|
||||
},
|
||||
{
|
||||
name: "missing ID token",
|
||||
authCode: "valid",
|
||||
wantErr: "received response missing ID token",
|
||||
},
|
||||
{
|
||||
name: "invalid ID token",
|
||||
authCode: "valid",
|
||||
returnIDTok: "invalid-jwt",
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
},
|
||||
{
|
||||
name: "invalid access token hash",
|
||||
authCode: "valid",
|
||||
returnIDTok: invalidAccessTokenHashIDToken,
|
||||
wantErr: "received invalid ID token: access token hash does not match value in ID token",
|
||||
},
|
||||
{
|
||||
name: "invalid nonce",
|
||||
authCode: "valid",
|
||||
expectNonce: "test-nonce",
|
||||
returnIDTok: invalidNonceIDToken,
|
||||
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
|
||||
},
|
||||
{
|
||||
name: "invalid nonce but not checked",
|
||||
authCode: "valid",
|
||||
expectNonce: "",
|
||||
returnIDTok: invalidNonceIDToken,
|
||||
wantToken: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Expiry: metav1.Time{},
|
||||
},
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: invalidNonceIDToken,
|
||||
Expiry: metav1.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
authCode: "valid",
|
||||
returnIDTok: validIDToken,
|
||||
wantToken: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Expiry: metav1.Time{},
|
||||
},
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: validIDToken,
|
||||
Expiry: metav1.Time{},
|
||||
},
|
||||
},
|
||||
wantClaims: map[string]interface{}{
|
||||
"foo": "bar",
|
||||
"bat": "baz",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
||||
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
|
||||
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
|
||||
require.NotEmpty(t, r.Form.Get("code"))
|
||||
if r.Form.Get("code") != "valid" {
|
||||
http.Error(w, "invalid authorization code", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
var response struct {
|
||||
oauth2.Token
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
response.AccessToken = "test-access-token"
|
||||
response.RefreshToken = "test-refresh-token"
|
||||
response.Expiry = time.Now().Add(time.Hour)
|
||||
response.IDToken = tt.returnIDTok
|
||||
w.Header().Set("content-type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||
}))
|
||||
t.Cleanup(tokenServer.Close)
|
||||
|
||||
p := ProviderConfig{
|
||||
Name: "test-name",
|
||||
UsernameClaim: "test-username-claim",
|
||||
GroupsClaim: "test-groups-claim",
|
||||
Config: &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://example.com",
|
||||
TokenURL: tokenServer.URL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
},
|
||||
Provider: &mockProvider{},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback")
|
||||
if tt.wantErr != "" {
|
||||
require.EqualError(t, err, tt.wantErr)
|
||||
require.Equal(t, oidctypes.Token{}, tok)
|
||||
require.Nil(t, claims)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantToken, tok)
|
||||
|
||||
for k, v := range tt.wantClaims {
|
||||
require.Equal(t, v, claims[k])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
|
||||
func mockVerifier() *oidc.IDTokenVerifier {
|
||||
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
|
||||
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().
|
||||
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.UnsafePayloadWithoutVerification(), nil
|
||||
})
|
||||
|
||||
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
|
||||
SkipIssuerCheck: true,
|
||||
SkipExpiryCheck: true,
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
}
|
||||
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) Verifier(_ *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
@ -17,6 +17,7 @@ import (
|
||||
"sigs.k8s.io/yaml"
|
||||
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -48,7 +49,7 @@ type (
|
||||
Key oidcclient.SessionCacheKey `json:"key"`
|
||||
CreationTimestamp metav1.Time `json:"creationTimestamp"`
|
||||
LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"`
|
||||
Tokens oidcclient.Token `json:"tokens"`
|
||||
Tokens oidctypes.Token `json:"tokens"`
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
// validSession should be the same data as `testdata/valid.yaml`.
|
||||
@ -28,17 +29,17 @@ var validSession = sessionCache{
|
||||
},
|
||||
CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()),
|
||||
LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -140,8 +141,8 @@ func TestNormalized(t *testing.T) {
|
||||
// ID token is empty, but not nil.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
IDToken: &oidcclient.IDToken{
|
||||
Tokens: oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
|
||||
},
|
||||
@ -150,8 +151,8 @@ func TestNormalized(t *testing.T) {
|
||||
// ID token is expired.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
IDToken: &oidcclient.IDToken{
|
||||
Tokens: oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
|
||||
},
|
||||
@ -160,8 +161,8 @@ func TestNormalized(t *testing.T) {
|
||||
// Access token is empty, but not nil.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
|
||||
},
|
||||
@ -170,8 +171,8 @@ func TestNormalized(t *testing.T) {
|
||||
// Access token is expired.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
|
||||
},
|
||||
@ -180,8 +181,8 @@ func TestNormalized(t *testing.T) {
|
||||
// Refresh token is empty, but not nil.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "",
|
||||
},
|
||||
},
|
||||
@ -189,8 +190,8 @@ func TestNormalized(t *testing.T) {
|
||||
// Session has a refresh token but it hasn't been used in >90 days.
|
||||
{
|
||||
LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -199,8 +200,8 @@ func TestNormalized(t *testing.T) {
|
||||
{
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token2",
|
||||
},
|
||||
},
|
||||
@ -208,8 +209,8 @@ func TestNormalized(t *testing.T) {
|
||||
{
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token1",
|
||||
},
|
||||
},
|
||||
@ -223,8 +224,8 @@ func TestNormalized(t *testing.T) {
|
||||
{
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token1",
|
||||
},
|
||||
},
|
||||
@ -232,8 +233,8 @@ func TestNormalized(t *testing.T) {
|
||||
{
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now),
|
||||
Tokens: oidcclient.Token{
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
Tokens: oidctypes.Token{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token2",
|
||||
},
|
||||
},
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -65,14 +66,14 @@ type Cache struct {
|
||||
}
|
||||
|
||||
// GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached.
|
||||
func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token {
|
||||
func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidctypes.Token {
|
||||
// If the cache file does not exist, exit immediately with no error log
|
||||
if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it.
|
||||
var result *oidcclient.Token
|
||||
var result *oidctypes.Token
|
||||
c.withCache(func(cache *sessionCache) {
|
||||
if entry := cache.lookup(key); entry != nil {
|
||||
result = &entry.Tokens
|
||||
@ -84,7 +85,7 @@ func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token {
|
||||
|
||||
// PutToken stores the provided token into the session cache under the given parameters. It does not return an error
|
||||
// but may silently fail to update the session cache.
|
||||
func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidcclient.Token) {
|
||||
func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidctypes.Token) {
|
||||
// Create the cache directory if it does not exist.
|
||||
if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) {
|
||||
c.errReporter(fmt.Errorf("could not create session cache directory: %w", err))
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
@ -38,7 +39,7 @@ func TestGetToken(t *testing.T) {
|
||||
trylockFunc func(*testing.T) error
|
||||
unlockFunc func(*testing.T) error
|
||||
key oidcclient.SessionCacheKey
|
||||
want *oidcclient.Token
|
||||
want *oidctypes.Token
|
||||
wantErrors []string
|
||||
wantTestFile func(t *testing.T, tmp string)
|
||||
}{
|
||||
@ -99,17 +100,17 @@ func TestGetToken(t *testing.T) {
|
||||
},
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -137,17 +138,17 @@ func TestGetToken(t *testing.T) {
|
||||
},
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -161,17 +162,17 @@ func TestGetToken(t *testing.T) {
|
||||
RedirectURI: "http://localhost:0/callback",
|
||||
},
|
||||
wantErrors: []string{},
|
||||
want: &oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
want: &oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -219,7 +220,7 @@ func TestPutToken(t *testing.T) {
|
||||
name string
|
||||
makeTestFile func(t *testing.T, tmp string)
|
||||
key oidcclient.SessionCacheKey
|
||||
token *oidcclient.Token
|
||||
token *oidctypes.Token
|
||||
wantErrors []string
|
||||
wantTestFile func(t *testing.T, tmp string)
|
||||
}{
|
||||
@ -245,17 +246,17 @@ func TestPutToken(t *testing.T) {
|
||||
},
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "old-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "old-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "old-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -269,17 +270,17 @@ func TestPutToken(t *testing.T) {
|
||||
Scopes: []string{"email", "offline_access", "openid", "profile"},
|
||||
RedirectURI: "http://localhost:0/callback",
|
||||
},
|
||||
token: &oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
token: &oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "new-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "new-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "new-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -288,17 +289,17 @@ func TestPutToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cache.Sessions, 1)
|
||||
require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
|
||||
require.Equal(t, oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
require.Equal(t, oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "new-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "new-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "new-refresh-token",
|
||||
},
|
||||
}, cache.Sessions[0].Tokens)
|
||||
@ -317,17 +318,17 @@ func TestPutToken(t *testing.T) {
|
||||
},
|
||||
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
|
||||
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
|
||||
Tokens: oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
Tokens: oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "old-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "old-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "old-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -341,17 +342,17 @@ func TestPutToken(t *testing.T) {
|
||||
Scopes: []string{"email", "offline_access", "openid", "profile"},
|
||||
RedirectURI: "http://localhost:0/callback",
|
||||
},
|
||||
token: &oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
token: &oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "new-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "new-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "new-refresh-token",
|
||||
},
|
||||
},
|
||||
@ -360,17 +361,17 @@ func TestPutToken(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, cache.Sessions, 2)
|
||||
require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
|
||||
require.Equal(t, oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
require.Equal(t, oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "new-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "new-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "new-refresh-token",
|
||||
},
|
||||
}, cache.Sessions[1].Tokens)
|
||||
@ -389,17 +390,17 @@ func TestPutToken(t *testing.T) {
|
||||
Scopes: []string{"email", "offline_access", "openid", "profile"},
|
||||
RedirectURI: "http://localhost:0/callback",
|
||||
},
|
||||
token: &oidcclient.Token{
|
||||
AccessToken: &oidcclient.AccessToken{
|
||||
token: &oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "new-access-token",
|
||||
Type: "Bearer",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
IDToken: &oidcclient.IDToken{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "new-id-token",
|
||||
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
|
||||
},
|
||||
RefreshToken: &oidcclient.RefreshToken{
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "new-refresh-token",
|
||||
},
|
||||
},
|
||||
|
@ -16,11 +16,13 @@ import (
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/pkg/browser"
|
||||
"golang.org/x/oauth2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/httputil/securityheader"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/upstreamoidc"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
"go.pinniped.dev/pkg/oidcclient/state"
|
||||
)
|
||||
@ -51,7 +53,7 @@ type handlerState struct {
|
||||
callbackPath string
|
||||
|
||||
// Generated parameters of a login flow.
|
||||
idTokenVerifier *oidc.IDTokenVerifier
|
||||
provider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
state state.State
|
||||
nonce nonce.Nonce
|
||||
@ -62,13 +64,13 @@ type handlerState struct {
|
||||
generatePKCE func() (pkce.Code, error)
|
||||
generateNonce func() (nonce.Nonce, error)
|
||||
openURL func(string) error
|
||||
oidcDiscover func(context.Context, string) (discoveryI, error)
|
||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||
|
||||
callbacks chan callbackResult
|
||||
}
|
||||
|
||||
type callbackResult struct {
|
||||
token *Token
|
||||
token *oidctypes.Token
|
||||
err error
|
||||
}
|
||||
|
||||
@ -87,6 +89,7 @@ func WithContext(ctx context.Context) Option {
|
||||
// WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the
|
||||
// authorization code callback. By default, a random high port will be chosen which requires the authorization server
|
||||
// to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252:
|
||||
//
|
||||
// The authorization server MUST allow any port to be specified at the
|
||||
// time of the request for loopback IP redirect URIs, to accommodate
|
||||
// clients that obtain an available ephemeral port from the operating
|
||||
@ -116,6 +119,19 @@ func WithBrowserOpen(openURL func(url string) error) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// SessionCacheKey contains the data used to select a valid session cache entry.
|
||||
type SessionCacheKey struct {
|
||||
Issuer string `json:"issuer"`
|
||||
ClientID string `json:"clientID"`
|
||||
Scopes []string `json:"scopes"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
type SessionCache interface {
|
||||
GetToken(SessionCacheKey) *oidctypes.Token
|
||||
PutToken(SessionCacheKey, *oidctypes.Token)
|
||||
}
|
||||
|
||||
// WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens.
|
||||
func WithSessionCache(cache SessionCache) Option {
|
||||
return func(h *handlerState) error {
|
||||
@ -135,16 +151,11 @@ func WithClient(httpClient *http.Client) Option {
|
||||
// nopCache is a SessionCache that doesn't actually do anything.
|
||||
type nopCache struct{}
|
||||
|
||||
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
|
||||
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
|
||||
|
||||
type discoveryI interface {
|
||||
Endpoint() oauth2.Endpoint
|
||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||
}
|
||||
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
|
||||
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
|
||||
|
||||
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
||||
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
|
||||
h := handlerState{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
@ -161,9 +172,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
generateNonce: nonce.Generate,
|
||||
generatePKCE: pkce.Generate,
|
||||
openURL: browser.OpenURL,
|
||||
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
return oidc.NewProvider(ctx, iss)
|
||||
},
|
||||
getProvider: upstreamoidc.New,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt(&h); err != nil {
|
||||
@ -208,16 +217,15 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
}
|
||||
|
||||
// Perform OIDC discovery.
|
||||
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
|
||||
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||
}
|
||||
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
|
||||
|
||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.clientID,
|
||||
Endpoint: discovered.Endpoint(),
|
||||
Endpoint: h.provider.Endpoint(),
|
||||
Scopes: h.scopes,
|
||||
}
|
||||
|
||||
@ -274,7 +282,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) {
|
||||
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
|
||||
defer cancel()
|
||||
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
|
||||
@ -287,7 +295,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshT
|
||||
|
||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
return h.validateToken(ctx, refreshed, false)
|
||||
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||
@ -314,58 +326,25 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
||||
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
|
||||
}
|
||||
|
||||
// Exchange the authorization code for access, ID, and refresh tokens.
|
||||
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier())
|
||||
// Exchange the authorization code for access, ID, and refresh tokens and perform required
|
||||
// validations on the returned ID token.
|
||||
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).
|
||||
ExchangeAuthcodeAndValidateTokens(
|
||||
r.Context(),
|
||||
params.Get("code"),
|
||||
h.pkce,
|
||||
h.nonce,
|
||||
h.oauth2Config.RedirectURL,
|
||||
)
|
||||
if err != nil {
|
||||
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
||||
}
|
||||
|
||||
// Perform required validations on the returned ID token.
|
||||
token, err := h.validateToken(r.Context(), oauth2Tok, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.callbacks <- callbackResult{token: token}
|
||||
h.callbacks <- callbackResult{token: &token}
|
||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
|
||||
if err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if checkNonce {
|
||||
if err := h.nonce.Validate(validated); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
return &Token{
|
||||
AccessToken: &AccessToken{
|
||||
Token: tok.AccessToken,
|
||||
Type: tok.TokenType,
|
||||
Expiry: metav1.NewTime(tok.Expiry),
|
||||
},
|
||||
RefreshToken: &RefreshToken{
|
||||
Token: tok.RefreshToken,
|
||||
},
|
||||
IDToken: &IDToken{
|
||||
Token: idTok,
|
||||
Expiry: metav1.NewTime(validated.Expiry),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handlerState) serve(listener net.Listener) func() {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||
|
@ -18,12 +18,14 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
||||
"go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
"go.pinniped.dev/pkg/oidcclient/state"
|
||||
)
|
||||
@ -31,19 +33,19 @@ import (
|
||||
// mockSessionCache exists to avoid an import cycle if we generate mocks into another package.
|
||||
type mockSessionCache struct {
|
||||
t *testing.T
|
||||
getReturnsToken *Token
|
||||
getReturnsToken *oidctypes.Token
|
||||
sawGetKeys []SessionCacheKey
|
||||
sawPutKeys []SessionCacheKey
|
||||
sawPutTokens []*Token
|
||||
sawPutTokens []*oidctypes.Token
|
||||
}
|
||||
|
||||
func (m *mockSessionCache) GetToken(key SessionCacheKey) *Token {
|
||||
func (m *mockSessionCache) GetToken(key SessionCacheKey) *oidctypes.Token {
|
||||
m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID)
|
||||
m.sawGetKeys = append(m.sawGetKeys, key)
|
||||
return m.getReturnsToken
|
||||
}
|
||||
|
||||
func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
|
||||
func (m *mockSessionCache) PutToken(key SessionCacheKey, token *oidctypes.Token) {
|
||||
m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token)
|
||||
m.sawPutKeys = append(m.sawPutKeys, key)
|
||||
m.sawPutTokens = append(m.sawPutTokens, token)
|
||||
@ -54,20 +56,10 @@ func TestLogin(t *testing.T) {
|
||||
time1Unix := int64(2075807775)
|
||||
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
||||
|
||||
testToken := Token{
|
||||
AccessToken: &AccessToken{
|
||||
Token: "test-access-token",
|
||||
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)),
|
||||
},
|
||||
RefreshToken: &RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
IDToken: &IDToken{
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
|
||||
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
|
||||
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
||||
},
|
||||
testToken := oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
|
||||
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
|
||||
}
|
||||
|
||||
// Start a test server that returns 500 errors
|
||||
@ -76,7 +68,7 @@ func TestLogin(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(errorServer.Close)
|
||||
|
||||
// Start a test server that returns a real keyset and answers refresh requests.
|
||||
// Start a test server that returns a real discovery document and answers refresh requests.
|
||||
providerMux := http.NewServeMux()
|
||||
successServer := httptest.NewServer(providerMux)
|
||||
t.Cleanup(successServer.Close)
|
||||
@ -144,7 +136,7 @@ func TestLogin(t *testing.T) {
|
||||
issuer string
|
||||
clientID string
|
||||
wantErr string
|
||||
wantToken *Token
|
||||
wantToken *oidctypes.Token
|
||||
}{
|
||||
{
|
||||
name: "option error",
|
||||
@ -191,8 +183,8 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "test-id-token",
|
||||
Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
@ -246,12 +238,20 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(testToken, nil, nil)
|
||||
return mock
|
||||
}
|
||||
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
cacheKey := SessionCacheKey{
|
||||
@ -266,12 +266,6 @@ func TestLogin(t *testing.T) {
|
||||
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
@ -283,12 +277,20 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
|
||||
return mock
|
||||
}
|
||||
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
|
||||
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
require.Empty(t, cache.sawPutKeys)
|
||||
@ -296,16 +298,10 @@ func TestLogin(t *testing.T) {
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
},
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
wantErr: "some validation error",
|
||||
},
|
||||
{
|
||||
name: "session cache hit but refresh fails",
|
||||
@ -313,12 +309,12 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "not-the-test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
require.Empty(t, cache.sawPutKeys)
|
||||
@ -326,12 +322,6 @@ func TestLogin(t *testing.T) {
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
h.listenAddr = "invalid-listen-address"
|
||||
|
||||
return nil
|
||||
@ -413,7 +403,7 @@ func TestLogin(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
|
||||
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
|
||||
require.Equal(t, []*Token{&testToken}, cache.sawPutTokens)
|
||||
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
|
||||
})
|
||||
require.NoError(t, WithSessionCache(cache)(h))
|
||||
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
|
||||
@ -481,7 +471,7 @@ func TestLogin(t *testing.T) {
|
||||
require.NotNil(t, tok.AccessToken)
|
||||
require.Equal(t, want.Token, tok.AccessToken.Token)
|
||||
require.Equal(t, want.Type, tok.AccessToken.Type)
|
||||
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
|
||||
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
|
||||
} else {
|
||||
assert.Nil(t, tok.AccessToken)
|
||||
}
|
||||
@ -489,7 +479,7 @@ func TestLogin(t *testing.T) {
|
||||
if want := tt.wantToken.IDToken; want != nil {
|
||||
require.NotNil(t, tok.IDToken)
|
||||
require.Equal(t, want.Token, tok.IDToken.Token)
|
||||
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
|
||||
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
|
||||
} else {
|
||||
assert.Nil(t, tok.IDToken)
|
||||
}
|
||||
@ -498,11 +488,13 @@ func TestLogin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
const testRedirectURI = "http://127.0.0.1:12324/callback"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
query string
|
||||
returnIDTok string
|
||||
opt func(t *testing.T) Option
|
||||
wantErr string
|
||||
wantHTTPStatus int
|
||||
}{
|
||||
@ -528,94 +520,51 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
{
|
||||
name: "invalid code",
|
||||
query: "state=test-state&code=invalid",
|
||||
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
|
||||
wantErr: "could not complete code exchange: some exchange error",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
|
||||
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
|
||||
return mock
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
{
|
||||
name: "missing ID token",
|
||||
query: "state=test-state&code=valid",
|
||||
returnIDTok: "",
|
||||
wantErr: "received response missing ID token",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid ID token",
|
||||
query: "state=test-state&code=valid",
|
||||
returnIDTok: "invalid-jwt",
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid access token hash",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
|
||||
|
||||
wantErr: "received invalid ID token: access token hash does not match value in ID token",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid nonce",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
|
||||
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
|
||||
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
|
||||
return mock
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
||||
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
|
||||
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
|
||||
require.NotEmpty(t, r.Form.Get("code"))
|
||||
if r.Form.Get("code") != "valid" {
|
||||
http.Error(w, "invalid authorization code", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
var response struct {
|
||||
oauth2.Token
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
response.AccessToken = "test-access-token"
|
||||
response.Expiry = time.Now().Add(time.Hour)
|
||||
response.IDToken = tt.returnIDTok
|
||||
w.Header().Set("content-type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||
}))
|
||||
t.Cleanup(tokenServer.Close)
|
||||
|
||||
h := &handlerState{
|
||||
callbacks: make(chan callbackResult, 1),
|
||||
state: state.State("test-state"),
|
||||
pkce: pkce.Code("test-pkce"),
|
||||
nonce: nonce.Nonce("test-nonce"),
|
||||
oauth2Config: &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
RedirectURL: "http://localhost:12345/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: tokenServer.URL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
},
|
||||
idTokenVerifier: mockVerifier(),
|
||||
}
|
||||
if tt.opt != nil {
|
||||
require.NoError(t, tt.opt(t)(h))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@ -651,47 +600,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, result.err)
|
||||
require.NotNil(t, result.token)
|
||||
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok)
|
||||
require.Equal(t, result.token.IDToken.Token, "test-id-token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
|
||||
func mockVerifier() *oidc.IDTokenVerifier {
|
||||
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
|
||||
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().
|
||||
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.UnsafePayloadWithoutVerification(), nil
|
||||
})
|
||||
|
||||
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
|
||||
SkipIssuerCheck: true,
|
||||
SkipExpiryCheck: true,
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
|
||||
}
|
||||
|
||||
type mockDiscovery struct{ provider *oidc.Provider }
|
||||
// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
|
||||
type hasAccessTokenMatcher struct{ expected string }
|
||||
|
||||
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
|
||||
|
||||
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
||||
|
||||
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
|
||||
require.InDeltaf(t,
|
||||
float64(t1.UnixNano()),
|
||||
float64(t2.UnixNano()),
|
||||
float64(delta.Nanoseconds()),
|
||||
"expected %s and %s to be < %s apart, but they are %s apart",
|
||||
t1.Format(time.RFC3339Nano),
|
||||
t2.Format(time.RFC3339Nano),
|
||||
delta.String(),
|
||||
t1.Sub(t2).String(),
|
||||
)
|
||||
func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
|
||||
return arg.(*oauth2.Token).AccessToken == m.expected
|
||||
}
|
||||
|
||||
func (m hasAccessTokenMatcher) Got(got interface{}) string {
|
||||
return got.(*oauth2.Token).AccessToken
|
||||
}
|
||||
|
||||
func (m hasAccessTokenMatcher) String() string {
|
||||
return m.expected
|
||||
}
|
||||
|
||||
func HasAccessToken(expected string) gomock.Matcher {
|
||||
return hasAccessTokenMatcher{expected: expected}
|
||||
}
|
||||
|
@ -1,11 +1,10 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oidcclient
|
||||
// Package oidctypes provides core data types for OIDC token structures.
|
||||
package oidctypes
|
||||
|
||||
import (
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
)
|
||||
import v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
// AccessToken is an OAuth2 access token.
|
||||
type AccessToken struct {
|
||||
@ -16,7 +15,7 @@ type AccessToken struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
|
||||
// Expiry is the optional expiration time of the access token.
|
||||
Expiry metav1.Time `json:"expiryTimestamp,omitempty"`
|
||||
Expiry v1.Time `json:"expiryTimestamp,omitempty"`
|
||||
}
|
||||
|
||||
// RefreshToken is an OAuth2 refresh token.
|
||||
@ -31,7 +30,7 @@ type IDToken struct {
|
||||
Token string `json:"token"`
|
||||
|
||||
// Expiry is the optional expiration time of the ID token.
|
||||
Expiry metav1.Time `json:"expiryTimestamp,omitempty"`
|
||||
Expiry v1.Time `json:"expiryTimestamp,omitempty"`
|
||||
}
|
||||
|
||||
// Token contains the elements of an OIDC session.
|
||||
@ -47,16 +46,3 @@ type Token struct {
|
||||
// IDToken is an OpenID Connect ID token.
|
||||
IDToken *IDToken `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
// SessionCacheKey contains the data used to select a valid session cache entry.
|
||||
type SessionCacheKey struct {
|
||||
Issuer string `json:"issuer"`
|
||||
ClientID string `json:"clientID"`
|
||||
Scopes []string `json:"scopes"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
type SessionCache interface {
|
||||
GetToken(SessionCacheKey) *Token
|
||||
PutToken(SessionCacheKey, *Token)
|
||||
}
|
@ -28,8 +28,7 @@ staticClients:
|
||||
name: 'Pinniped Supervisor'
|
||||
secret: pinniped-supervisor-secret
|
||||
redirectURIs:
|
||||
- #@ "http://127.0.0.1:" + str(data.values.ports.cli) + "/callback"
|
||||
- #@ "http://[::1]:" + str(data.values.ports.cli) + "/callback"
|
||||
- https://pinniped-supervisor-clusterip.supervisor.svc.cluster.local/some/path/callback
|
||||
enablePasswordDB: true
|
||||
staticPasswords:
|
||||
- username: "pinny"
|
||||
|
@ -20,7 +20,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sclevine/agouti"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
@ -30,6 +29,7 @@ import (
|
||||
"go.pinniped.dev/pkg/oidcclient"
|
||||
"go.pinniped.dev/pkg/oidcclient/filesession"
|
||||
"go.pinniped.dev/test/library"
|
||||
"go.pinniped.dev/test/library/browsertest"
|
||||
)
|
||||
|
||||
func TestCLIGetKubeconfig(t *testing.T) {
|
||||
@ -108,80 +108,14 @@ func runPinnipedCLIGetKubeconfig(t *testing.T, pinnipedExe, token, namespaceName
|
||||
return string(output)
|
||||
}
|
||||
|
||||
type loginProviderPatterns struct {
|
||||
Name string
|
||||
IssuerPattern *regexp.Regexp
|
||||
LoginPagePattern *regexp.Regexp
|
||||
UsernameSelector string
|
||||
PasswordSelector string
|
||||
LoginButtonSelector string
|
||||
}
|
||||
|
||||
func getLoginProvider(t *testing.T) *loginProviderPatterns {
|
||||
t.Helper()
|
||||
issuer := library.IntegrationEnv(t).CLITestUpstream.Issuer
|
||||
for _, p := range []loginProviderPatterns{
|
||||
{
|
||||
Name: "Okta",
|
||||
IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
|
||||
LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
|
||||
UsernameSelector: "input#okta-signin-username",
|
||||
PasswordSelector: "input#okta-signin-password",
|
||||
LoginButtonSelector: "input#okta-signin-submit",
|
||||
},
|
||||
{
|
||||
Name: "Dex",
|
||||
IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`),
|
||||
LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`),
|
||||
UsernameSelector: "input#login",
|
||||
PasswordSelector: "input#password",
|
||||
LoginButtonSelector: "button#submit-login",
|
||||
},
|
||||
} {
|
||||
if p.IssuerPattern.MatchString(issuer) {
|
||||
return &p
|
||||
}
|
||||
}
|
||||
require.Failf(t, "could not find login provider for issuer %q", issuer)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCLILoginOIDC(t *testing.T) {
|
||||
env := library.IntegrationEnv(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Find the login CSS selectors for the test issuer, or fail fast.
|
||||
loginProvider := getLoginProvider(t)
|
||||
|
||||
// Start the browser driver.
|
||||
t.Logf("opening browser driver")
|
||||
caps := agouti.NewCapabilities()
|
||||
if env.Proxy != "" {
|
||||
t.Logf("configuring Chrome to use proxy %q", env.Proxy)
|
||||
caps = caps.Proxy(agouti.ProxyConfig{
|
||||
ProxyType: "manual",
|
||||
HTTPProxy: env.Proxy,
|
||||
SSLProxy: env.Proxy,
|
||||
NoProxy: "127.0.0.1",
|
||||
})
|
||||
}
|
||||
agoutiDriver := agouti.ChromeDriver(
|
||||
agouti.Desired(caps),
|
||||
agouti.ChromeOptions("args", []string{
|
||||
"--no-sandbox",
|
||||
"--ignore-certificate-errors",
|
||||
"--headless", // Comment out this line to see the tests happen in a visible browser window.
|
||||
}),
|
||||
// Uncomment this to see stdout/stderr from chromedriver.
|
||||
// agouti.Debug,
|
||||
)
|
||||
require.NoError(t, agoutiDriver.Start())
|
||||
t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) })
|
||||
page, err := agoutiDriver.NewPage(agouti.Browser("chrome"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, page.Reset())
|
||||
page := browsertest.Open(t)
|
||||
|
||||
// Build pinniped CLI.
|
||||
t.Logf("building CLI binary")
|
||||
@ -262,28 +196,18 @@ func TestCLILoginOIDC(t *testing.T) {
|
||||
t.Logf("navigating to login page")
|
||||
require.NoError(t, page.Navigate(loginURL))
|
||||
|
||||
// Expect to be redirected to the login page.
|
||||
t.Logf("waiting for redirect to %s login page", loginProvider.Name)
|
||||
waitForURL(t, page, loginProvider.LoginPagePattern)
|
||||
// Expect to be redirected to the upstream provider and log in.
|
||||
browsertest.LoginToUpstream(t, page, env.CLITestUpstream)
|
||||
|
||||
// Wait for the login page to be rendered.
|
||||
waitForVisibleElements(t, page, loginProvider.UsernameSelector, loginProvider.PasswordSelector, loginProvider.LoginButtonSelector)
|
||||
|
||||
// Fill in the username and password and click "submit".
|
||||
t.Logf("logging into %s", loginProvider.Name)
|
||||
require.NoError(t, page.First(loginProvider.UsernameSelector).Fill(env.CLITestUpstream.Username))
|
||||
require.NoError(t, page.First(loginProvider.PasswordSelector).Fill(env.CLITestUpstream.Password))
|
||||
require.NoError(t, page.First(loginProvider.LoginButtonSelector).Click())
|
||||
|
||||
// Wait for the login to happen and us be redirected back to a localhost callback.
|
||||
t.Logf("waiting for redirect to localhost callback")
|
||||
// Expect to be redirected to the localhost callback.
|
||||
t.Logf("waiting for redirect to callback")
|
||||
callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(env.CLITestUpstream.CallbackURL) + `\?.+\z`)
|
||||
waitForURL(t, page, callbackURLPattern)
|
||||
browsertest.WaitForURL(t, page, callbackURLPattern)
|
||||
|
||||
// Wait for the "pre" element that gets rendered for a `text/plain` page, and
|
||||
// assert that it contains the success message.
|
||||
t.Logf("verifying success page")
|
||||
waitForVisibleElements(t, page, "pre")
|
||||
browsertest.WaitForVisibleElements(t, page, "pre")
|
||||
msg, err := page.First("pre").Text()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "you have been logged in and may now close this tab", msg)
|
||||
@ -361,44 +285,6 @@ func TestCLILoginOIDC(t *testing.T) {
|
||||
require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token)
|
||||
}
|
||||
|
||||
func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {
|
||||
t.Helper()
|
||||
require.Eventually(t,
|
||||
func() bool {
|
||||
for _, sel := range selectors {
|
||||
vis, err := page.First(sel).Visible()
|
||||
if !(err == nil && vis) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
10*time.Second,
|
||||
100*time.Millisecond,
|
||||
)
|
||||
}
|
||||
|
||||
func waitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) {
|
||||
var lastURL string
|
||||
require.Eventuallyf(t,
|
||||
func() bool {
|
||||
url, err := page.URL()
|
||||
if err == nil && pat.MatchString(url) {
|
||||
return true
|
||||
}
|
||||
if url != lastURL {
|
||||
t.Logf("saw URL %s", url)
|
||||
lastURL = url
|
||||
}
|
||||
return false
|
||||
},
|
||||
10*time.Second,
|
||||
100*time.Millisecond,
|
||||
"expected to browse to %s, but never got there",
|
||||
pat,
|
||||
)
|
||||
}
|
||||
|
||||
func readAndExpectEmpty(r io.Reader) (err error) {
|
||||
var remainder bytes.Buffer
|
||||
_, err = io.Copy(&remainder, r)
|
||||
|
@ -17,7 +17,7 @@ import (
|
||||
"k8s.io/apimachinery/pkg/api/errors"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/fosite/authorizationcode"
|
||||
"go.pinniped.dev/internal/fositestorage/authorizationcode"
|
||||
"go.pinniped.dev/test/library"
|
||||
)
|
||||
|
||||
@ -29,7 +29,7 @@ func TestAuthorizeCodeStorage(t *testing.T) {
|
||||
// randomly generated HMAC authorization code (see below)
|
||||
code = "TQ72B8YjdEOZyxridYbTLE-pzoK4hpdkZxym5j4EmSc.TKRTgQG41IBQ16FDKTthRdhXfLlNaErcMd9Fy47uXAw"
|
||||
// name of the secret that will be created in Kube
|
||||
name = "pinniped-storage-authorization-codes-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga"
|
||||
name = "pinniped-storage-authcode-jssfhaibxdkiaugxufbsso3bixmfo7fzjvuevxbr35c4xdxolqga"
|
||||
)
|
||||
|
||||
hmac := compose.NewOAuth2HMACStrategy(&compose.Config{}, []byte("super-secret-32-byte-for-testing"), nil)
|
||||
|
@ -111,7 +111,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) {
|
||||
|
||||
// When the same issuer is added twice, both issuers are marked as duplicates, and neither provider is serving.
|
||||
config6Duplicate1, _ := requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(ctx, t, scheme, addr, caBundle, issuer6, client)
|
||||
config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "")
|
||||
config6Duplicate2 := library.CreateTestOIDCProvider(ctx, t, issuer6, "", "")
|
||||
requireStatus(t, client, ns, config6Duplicate1.Name, v1alpha1.DuplicateOIDCProviderStatusCondition)
|
||||
requireStatus(t, client, ns, config6Duplicate2.Name, v1alpha1.DuplicateOIDCProviderStatusCondition)
|
||||
requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, issuer6)
|
||||
@ -136,7 +136,7 @@ func TestSupervisorOIDCDiscovery(t *testing.T) {
|
||||
}
|
||||
|
||||
// When we create a provider with an invalid issuer, the status is set to invalid.
|
||||
badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "")
|
||||
badConfig := library.CreateTestOIDCProvider(ctx, t, badIssuer, "", "")
|
||||
requireStatus(t, client, ns, badConfig.Name, v1alpha1.InvalidOIDCProviderStatusCondition)
|
||||
requireDiscoveryEndpointsAreNotFound(t, scheme, addr, caBundle, badIssuer)
|
||||
requireDeletingOIDCProviderCausesDiscoveryEndpointsToDisappear(t, badConfig, client, ns, scheme, addr, caBundle, badIssuer)
|
||||
@ -162,7 +162,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) {
|
||||
certSecretName1 := "integration-test-cert-1"
|
||||
|
||||
// Create an OIDCProvider with a spec.tls.secretName.
|
||||
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1)
|
||||
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuer1, certSecretName1, "")
|
||||
requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
|
||||
|
||||
// The spec.tls.secretName Secret does not exist, so the endpoints should fail with TLS errors.
|
||||
@ -198,7 +198,7 @@ func TestSupervisorTLSTerminationWithSNI(t *testing.T) {
|
||||
certSecretName2 := "integration-test-cert-2"
|
||||
|
||||
// Create an OIDCProvider with a spec.tls.secretName.
|
||||
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2)
|
||||
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuer2, certSecretName2, "")
|
||||
requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
|
||||
|
||||
// Create the Secret.
|
||||
@ -232,31 +232,30 @@ func TestSupervisorTLSTerminationWithDefaultCerts(t *testing.T) {
|
||||
port = hostAndPortSegments[1]
|
||||
}
|
||||
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
|
||||
ips, err := library.LookupIP(ctx, hostname)
|
||||
require.NoError(t, err)
|
||||
ip := ips[0]
|
||||
ipAsString := ip.String()
|
||||
ipWithPort := ipAsString + ":" + port
|
||||
require.NotEmpty(t, ips)
|
||||
ipWithPort := ips[0].String() + ":" + port
|
||||
|
||||
issuerUsingIPAddress := fmt.Sprintf("%s://%s/issuer1", scheme, ipWithPort)
|
||||
issuerUsingHostname := fmt.Sprintf("%s://%s/issuer1", scheme, address)
|
||||
|
||||
// Create an OIDCProvider without a spec.tls.secretName.
|
||||
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "")
|
||||
oidcProvider1 := library.CreateTestOIDCProvider(ctx, t, issuerUsingIPAddress, "", "")
|
||||
requireStatus(t, pinnipedClient, oidcProvider1.Namespace, oidcProvider1.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
|
||||
|
||||
// There is no default TLS cert and the spec.tls.secretName was not set, so the endpoints should fail with TLS errors.
|
||||
requireEndpointHasTLSErrorBecauseCertificatesAreNotReady(t, issuerUsingIPAddress)
|
||||
|
||||
// Create a Secret at the special name which represents the default TLS cert.
|
||||
defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ip.IP}, defaultTLSCertSecretName(env), kubeClient)
|
||||
defaultCA := createTLSCertificateSecret(ctx, t, ns, "cert-hostname-doesnt-matter", []net.IP{ips[0]}, defaultTLSCertSecretName(env), kubeClient)
|
||||
|
||||
// Now that the Secret exists, we should be able to access the endpoints by IP address using the CA.
|
||||
_ = requireDiscoveryEndpointsAreWorking(t, scheme, ipWithPort, string(defaultCA.Bundle()), issuerUsingIPAddress, nil)
|
||||
|
||||
// Create an OIDCProvider with a spec.tls.secretName.
|
||||
certSecretName := "integration-test-cert-1"
|
||||
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName)
|
||||
oidcProvider2 := library.CreateTestOIDCProvider(ctx, t, issuerUsingHostname, certSecretName, "")
|
||||
requireStatus(t, pinnipedClient, oidcProvider2.Namespace, oidcProvider2.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
|
||||
|
||||
// Create the Secret.
|
||||
@ -429,7 +428,7 @@ func requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(
|
||||
client pinnipedclientset.Interface,
|
||||
) (*v1alpha1.OIDCProvider, *ExpectedJWKSResponseFormat) {
|
||||
t.Helper()
|
||||
newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "")
|
||||
newOIDCProvider := library.CreateTestOIDCProvider(ctx, t, issuerName, "", "")
|
||||
jwksResult := requireDiscoveryEndpointsAreWorking(t, supervisorScheme, supervisorAddress, supervisorCABundle, issuerName, nil)
|
||||
requireStatus(t, client, newOIDCProvider.Namespace, newOIDCProvider.Name, v1alpha1.SuccessOIDCProviderStatusCondition)
|
||||
return newOIDCProvider, jwksResult
|
||||
|
@ -27,7 +27,7 @@ func TestSupervisorOIDCKeys(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
// Create our OPC under test.
|
||||
opc := library.CreateTestOIDCProvider(ctx, t, "", "")
|
||||
opc := library.CreateTestOIDCProvider(ctx, t, "", "", "")
|
||||
|
||||
// Ensure a secret is created with the OPC's JWKS.
|
||||
var updatedOPC *configv1alpha1.OIDCProvider
|
||||
|
@ -6,236 +6,180 @@ package integration
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1"
|
||||
idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
|
||||
"go.pinniped.dev/internal/certauthority"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
"go.pinniped.dev/pkg/oidcclient/state"
|
||||
"go.pinniped.dev/test/library"
|
||||
"go.pinniped.dev/test/library/browsertest"
|
||||
)
|
||||
|
||||
func TestSupervisorLogin(t *testing.T) {
|
||||
t.Skip("waiting on new callback path logic to get merged in from the callback endpoint work")
|
||||
|
||||
env := library.IntegrationEnv(t)
|
||||
client := library.NewSupervisorClientset(t)
|
||||
|
||||
// If anything in this test crashes, dump out the supervisor pod logs.
|
||||
defer library.DumpLogs(t, env.SupervisorNamespace)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
Scheme string
|
||||
Address string
|
||||
CABundle string
|
||||
}{
|
||||
{Scheme: "http", Address: env.SupervisorHTTPAddress},
|
||||
{Scheme: "https", Address: env.SupervisorHTTPSIngressAddress, CABundle: env.SupervisorHTTPSIngressCABundle},
|
||||
// Infer the downstream issuer URL from the callback associated with the upstream test client registration.
|
||||
issuerURL, err := url.Parse(env.SupervisorTestUpstream.CallbackURL)
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.HasSuffix(issuerURL.Path, "/callback"))
|
||||
issuerURL.Path = strings.TrimSuffix(issuerURL.Path, "/callback")
|
||||
t.Logf("testing with downstream issuer URL %s", issuerURL.String())
|
||||
|
||||
// Generate a CA bundle with which to serve this provider.
|
||||
t.Logf("generating test CA")
|
||||
ca, err := certauthority.New(pkix.Name{CommonName: "Downstream Test CA"}, 1*time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an HTTP client that can reach the downstream discovery endpoint using the CA certs.
|
||||
httpClient := &http.Client{Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{RootCAs: ca.Pool()},
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
if env.Proxy == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return url.Parse(env.Proxy)
|
||||
},
|
||||
}}
|
||||
|
||||
for _, test := range tests {
|
||||
scheme := test.Scheme
|
||||
addr := test.Address
|
||||
caBundle := test.CABundle
|
||||
// Use the CA to issue a TLS server cert.
|
||||
t.Logf("issuing test certificate")
|
||||
tlsCert, err := ca.Issue(
|
||||
pkix.Name{CommonName: issuerURL.Hostname()},
|
||||
[]string{issuerURL.Hostname()},
|
||||
nil,
|
||||
1*time.Hour,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
certPEM, keyPEM, err := certauthority.ToPEM(tlsCert)
|
||||
require.NoError(t, err)
|
||||
|
||||
if addr == "" {
|
||||
// Both cases are not required, so when one is empty skip it.
|
||||
continue
|
||||
}
|
||||
|
||||
// Create downstream OIDC provider (i.e., update supervisor with OIDC provider).
|
||||
path := getDownstreamIssuerPathFromUpstreamRedirectURI(t, env.SupervisorTestUpstream.CallbackURL)
|
||||
issuer := fmt.Sprintf("https://%s%s", addr, path)
|
||||
_, _ = requireCreatingOIDCProviderCausesDiscoveryEndpointsToAppear(
|
||||
ctx,
|
||||
t,
|
||||
scheme,
|
||||
addr,
|
||||
caBundle,
|
||||
issuer,
|
||||
client,
|
||||
// Write the serving cert to a secret.
|
||||
certSecret := library.CreateTestSecret(t,
|
||||
env.SupervisorNamespace,
|
||||
"oidc-provider-tls",
|
||||
"kubernetes.io/tls",
|
||||
map[string]string{"tls.crt": string(certPEM), "tls.key": string(keyPEM)},
|
||||
)
|
||||
|
||||
// Create HTTP client.
|
||||
httpClient := newHTTPClient(t, caBundle, nil)
|
||||
httpClient.CheckRedirect = func(_ *http.Request, _ []*http.Request) error {
|
||||
// Don't follow any redirects right now, since we simply want to validate that our auth endpoint
|
||||
// redirects us.
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
// Create the downstream OIDCProvider and expect it to go into the success status condition.
|
||||
downstream := library.CreateTestOIDCProvider(ctx, t,
|
||||
issuerURL.String(),
|
||||
certSecret.Name,
|
||||
configv1alpha1.SuccessOIDCProviderStatusCondition,
|
||||
)
|
||||
|
||||
// Declare the downstream auth endpoint url we will use.
|
||||
downstreamAuthURL := makeDownstreamAuthURL(t, scheme, addr, path)
|
||||
|
||||
// Make request to auth endpoint - should fail, since we have no upstreams.
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil)
|
||||
require.NoError(t, err)
|
||||
rsp, err := httpClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer rsp.Body.Close()
|
||||
require.Equal(t, http.StatusUnprocessableEntity, rsp.StatusCode)
|
||||
|
||||
// Create upstream OIDC provider.
|
||||
spec := idpv1alpha1.UpstreamOIDCProviderSpec{
|
||||
// Create upstream OIDC provider and wait for it to become ready.
|
||||
library.CreateTestUpstreamOIDCProvider(t, idpv1alpha1.UpstreamOIDCProviderSpec{
|
||||
Issuer: env.SupervisorTestUpstream.Issuer,
|
||||
TLS: &idpv1alpha1.TLSSpec{
|
||||
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorTestUpstream.CABundle)),
|
||||
},
|
||||
Client: idpv1alpha1.OIDCClient{
|
||||
SecretName: makeTestClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name,
|
||||
SecretName: library.CreateClientCredsSecret(t, env.SupervisorTestUpstream.ClientID, env.SupervisorTestUpstream.ClientSecret).Name,
|
||||
},
|
||||
}
|
||||
upstream := makeTestUpstream(t, spec, idpv1alpha1.PhaseReady)
|
||||
}, idpv1alpha1.PhaseReady)
|
||||
|
||||
// Make request to authorize endpoint - should pass, since we now have an upstream.
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthURL, nil)
|
||||
require.NoError(t, err)
|
||||
rsp, err = httpClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer rsp.Body.Close()
|
||||
require.Equal(t, http.StatusFound, rsp.StatusCode)
|
||||
requireValidRedirectLocation(
|
||||
ctx,
|
||||
t,
|
||||
upstream.Spec.Issuer,
|
||||
env.SupervisorTestUpstream.ClientID,
|
||||
env.SupervisorTestUpstream.CallbackURL,
|
||||
rsp.Header.Get("Location"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func getDownstreamIssuerPathFromUpstreamRedirectURI(t *testing.T, upstreamRedirectURI string) string {
|
||||
// We need to construct the downstream issuer path from the upstream redirect URI since the two
|
||||
// are related, and the upstream redirect URI is supplied via a static test environment
|
||||
// variable. The upstream redirect URI should be something like
|
||||
// https://supervisor.com/some/supervisor/path/callback
|
||||
// and therefore the downstream issuer should be something like
|
||||
// https://supervisor.com/some/supervisor/path
|
||||
// since the /callback endpoint is placed at the root of the downstream issuer path.
|
||||
upstreamRedirectURL, err := url.Parse(upstreamRedirectURI)
|
||||
// Perform OIDC discovery for our downstream.
|
||||
var discovery *oidc.Provider
|
||||
assert.Eventually(t, func() bool {
|
||||
discovery, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), downstream.Spec.Issuer)
|
||||
return err == nil
|
||||
}, 60*time.Second, 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
redirectURIPathWithoutLastSegment, lastUpstreamRedirectURIPathSegment := path.Split(upstreamRedirectURL.Path)
|
||||
require.Equalf(
|
||||
t,
|
||||
"callback",
|
||||
lastUpstreamRedirectURIPathSegment,
|
||||
"expected upstream redirect URI (%q) to follow supervisor callback path conventions (i.e., end in /callback)",
|
||||
upstreamRedirectURI,
|
||||
)
|
||||
// Start a callback server on localhost.
|
||||
localCallbackServer := startLocalCallbackServer(t)
|
||||
|
||||
if strings.HasSuffix(redirectURIPathWithoutLastSegment, "/") {
|
||||
redirectURIPathWithoutLastSegment = redirectURIPathWithoutLastSegment[:len(redirectURIPathWithoutLastSegment)-1]
|
||||
}
|
||||
|
||||
return redirectURIPathWithoutLastSegment
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func makeDownstreamAuthURL(t *testing.T, scheme, addr, path string) string {
|
||||
t.Helper()
|
||||
// Form the OAuth2 configuration corresponding to our CLI client.
|
||||
downstreamOAuth2Config := oauth2.Config{
|
||||
// This is the hardcoded public client that the supervisor supports.
|
||||
ClientID: "pinniped-cli",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("%s://%s%s/oauth2/authorize", scheme, addr, path),
|
||||
},
|
||||
// This is the hardcoded downstream redirect URI that the supervisor supports.
|
||||
RedirectURL: "http://127.0.0.1/callback",
|
||||
Endpoint: discovery.Endpoint(),
|
||||
RedirectURL: localCallbackServer.URL,
|
||||
Scopes: []string{"openid"},
|
||||
}
|
||||
state, nonce, pkce := generateAuthRequestParams(t)
|
||||
return downstreamOAuth2Config.AuthCodeURL(
|
||||
state.String(),
|
||||
nonce.Param(),
|
||||
pkce.Challenge(),
|
||||
pkce.Method(),
|
||||
|
||||
// Build a valid downstream authorize URL for the supervisor.
|
||||
stateParam, err := state.Generate()
|
||||
require.NoError(t, err)
|
||||
nonceParam, err := nonce.Generate()
|
||||
require.NoError(t, err)
|
||||
pkceParam, err := pkce.Generate()
|
||||
require.NoError(t, err)
|
||||
downstreamAuthorizeURL := downstreamOAuth2Config.AuthCodeURL(
|
||||
stateParam.String(),
|
||||
nonceParam.Param(),
|
||||
pkceParam.Challenge(),
|
||||
pkceParam.Method(),
|
||||
)
|
||||
|
||||
// Open the web browser and navigate to the downstream authorize URL.
|
||||
page := browsertest.Open(t)
|
||||
t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL))
|
||||
require.NoError(t, page.Navigate(downstreamAuthorizeURL))
|
||||
|
||||
// Expect to be redirected to the upstream provider and log in.
|
||||
browsertest.LoginToUpstream(t, page, env.SupervisorTestUpstream)
|
||||
|
||||
// Wait for the login to happen and us be redirected back to a localhost callback.
|
||||
t.Logf("waiting for redirect to callback")
|
||||
callbackURLPattern := regexp.MustCompile(`\A` + regexp.QuoteMeta(localCallbackServer.URL) + `\?.+\z`)
|
||||
browsertest.WaitForURL(t, page, callbackURLPattern)
|
||||
|
||||
// Expect that our callback handler was invoked.
|
||||
callback := localCallbackServer.waitForCallback(10 * time.Second)
|
||||
t.Logf("got callback request: %s", library.MaskTokens(callback.URL.String()))
|
||||
require.Equal(t, stateParam.String(), callback.URL.Query().Get("state"))
|
||||
require.Equal(t, "openid", callback.URL.Query().Get("scope"))
|
||||
require.NotEmpty(t, callback.URL.Query().Get("code"))
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func generateAuthRequestParams(t *testing.T) (state.State, nonce.Nonce, pkce.Code) {
|
||||
t.Helper()
|
||||
state, err := state.Generate()
|
||||
require.NoError(t, err)
|
||||
nonce, err := nonce.Generate()
|
||||
require.NoError(t, err)
|
||||
pkce, err := pkce.Generate()
|
||||
require.NoError(t, err)
|
||||
return state, nonce, pkce
|
||||
func startLocalCallbackServer(t *testing.T) *localCallbackServer {
|
||||
// Handle the callback by sending the *http.Request object back through a channel.
|
||||
callbacks := make(chan *http.Request, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callbacks <- r
|
||||
}))
|
||||
server.URL += "/callback"
|
||||
t.Cleanup(server.Close)
|
||||
t.Cleanup(func() { close(callbacks) })
|
||||
return &localCallbackServer{Server: server, t: t, callbacks: callbacks}
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func requireValidRedirectLocation(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
issuer, clientID, redirectURI, actualLocation string,
|
||||
) {
|
||||
t.Helper()
|
||||
env := library.IntegrationEnv(t)
|
||||
|
||||
// Do OIDC discovery on our test issuer to get auth endpoint.
|
||||
transport := http.Transport{}
|
||||
if env.Proxy != "" {
|
||||
transport.Proxy = func(_ *http.Request) (*url.URL, error) {
|
||||
return url.Parse(env.Proxy)
|
||||
}
|
||||
}
|
||||
if env.SupervisorTestUpstream.CABundle != "" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: x509.NewCertPool()}
|
||||
transport.TLSClientConfig.RootCAs.AppendCertsFromPEM([]byte(env.SupervisorTestUpstream.CABundle))
|
||||
type localCallbackServer struct {
|
||||
*httptest.Server
|
||||
t *testing.T
|
||||
callbacks <-chan *http.Request
|
||||
}
|
||||
|
||||
ctx = oidc.ClientContext(ctx, &http.Client{Transport: &transport})
|
||||
upstreamProvider, err := oidc.NewProvider(ctx, issuer)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse expected upstream auth URL.
|
||||
expectedLocationURL, err := url.Parse(
|
||||
(&oauth2.Config{
|
||||
ClientID: clientID,
|
||||
Endpoint: upstreamProvider.Endpoint(),
|
||||
RedirectURL: redirectURI,
|
||||
Scopes: []string{"openid"},
|
||||
}).AuthCodeURL("", oauth2.AccessTypeOffline),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse actual upstream auth URL.
|
||||
actualLocationURL, err := url.Parse(actualLocation)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First make some assertions on the query values. Note that we will not be able to know what
|
||||
// certain query values are since they may be random (e.g., state, pkce, nonce).
|
||||
expectedLocationQuery := expectedLocationURL.Query()
|
||||
actualLocationQuery := actualLocationURL.Query()
|
||||
require.NotEmpty(t, actualLocationQuery.Get("state"))
|
||||
actualLocationQuery.Del("state")
|
||||
require.NotEmpty(t, actualLocationQuery.Get("code_challenge"))
|
||||
actualLocationQuery.Del("code_challenge")
|
||||
require.NotEmpty(t, actualLocationQuery.Get("code_challenge_method"))
|
||||
actualLocationQuery.Del("code_challenge_method")
|
||||
require.NotEmpty(t, actualLocationQuery.Get("nonce"))
|
||||
actualLocationQuery.Del("nonce")
|
||||
require.Equal(t, expectedLocationQuery, actualLocationQuery)
|
||||
|
||||
// Zero-out query values, since we made specific assertions about those above, and assert that the
|
||||
// URL's are equal otherwise.
|
||||
expectedLocationURL.RawQuery = ""
|
||||
actualLocationURL.RawQuery = ""
|
||||
require.Equal(t, expectedLocationURL, actualLocationURL)
|
||||
func (s *localCallbackServer) waitForCallback(timeout time.Duration) *http.Request {
|
||||
select {
|
||||
case callback := <-s.callbacks:
|
||||
return callback
|
||||
case <-time.After(timeout):
|
||||
require.Fail(s.t, "timed out waiting for callback request")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -4,13 +4,10 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
|
||||
@ -28,7 +25,7 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) {
|
||||
SecretName: "does-not-exist",
|
||||
},
|
||||
}
|
||||
upstream := makeTestUpstream(t, spec, v1alpha1.PhaseError)
|
||||
upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseError)
|
||||
expectUpstreamConditions(t, upstream, []v1alpha1.Condition{
|
||||
{
|
||||
Type: "ClientCredentialsValid",
|
||||
@ -56,10 +53,10 @@ func TestSupervisorUpstreamOIDCDiscovery(t *testing.T) {
|
||||
AdditionalScopes: []string{"email", "profile"},
|
||||
},
|
||||
Client: v1alpha1.OIDCClient{
|
||||
SecretName: makeTestClientCredsSecret(t, "test-client-id", "test-client-secret").Name,
|
||||
SecretName: library.CreateClientCredsSecret(t, "test-client-id", "test-client-secret").Name,
|
||||
},
|
||||
}
|
||||
upstream := makeTestUpstream(t, spec, v1alpha1.PhaseReady)
|
||||
upstream := library.CreateTestUpstreamOIDCProvider(t, spec, v1alpha1.PhaseReady)
|
||||
expectUpstreamConditions(t, upstream, []v1alpha1.Condition{
|
||||
{
|
||||
Type: "ClientCredentialsValid",
|
||||
@ -87,74 +84,3 @@ func expectUpstreamConditions(t *testing.T, upstream *v1alpha1.UpstreamOIDCProvi
|
||||
}
|
||||
require.ElementsMatch(t, expected, normalized)
|
||||
}
|
||||
|
||||
func makeTestClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret {
|
||||
t.Helper()
|
||||
env := library.IntegrationEnv(t)
|
||||
client := library.NewClientset(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
created, err := client.CoreV1().Secrets(env.SupervisorNamespace).Create(ctx, &corev1.Secret{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Namespace: env.SupervisorNamespace,
|
||||
GenerateName: "test-client-creds-",
|
||||
Labels: map[string]string{"pinniped.dev/test": ""},
|
||||
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
|
||||
},
|
||||
Type: "secrets.pinniped.dev/oidc-client",
|
||||
StringData: map[string]string{
|
||||
"clientID": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
},
|
||||
}, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.CoreV1().Secrets(env.SupervisorNamespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Logf("created test client credentials Secret %s", created.Name)
|
||||
return created
|
||||
}
|
||||
|
||||
func makeTestUpstream(t *testing.T, spec v1alpha1.UpstreamOIDCProviderSpec, expectedPhase v1alpha1.UpstreamOIDCProviderPhase) *v1alpha1.UpstreamOIDCProvider {
|
||||
t.Helper()
|
||||
env := library.IntegrationEnv(t)
|
||||
client := library.NewSupervisorClientset(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Create the UpstreamOIDCProvider using GenerateName to get a random name.
|
||||
created, err := client.IDPV1alpha1().
|
||||
UpstreamOIDCProviders(env.SupervisorNamespace).
|
||||
Create(ctx, &v1alpha1.UpstreamOIDCProvider{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
Namespace: env.SupervisorNamespace,
|
||||
GenerateName: "test-upstream-",
|
||||
Labels: map[string]string{"pinniped.dev/test": ""},
|
||||
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
|
||||
},
|
||||
Spec: spec,
|
||||
}, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Always clean this up after this point.
|
||||
t.Cleanup(func() {
|
||||
err := client.IDPV1alpha1().
|
||||
UpstreamOIDCProviders(env.SupervisorNamespace).
|
||||
Delete(context.Background(), created.Name, metav1.DeleteOptions{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Logf("created test UpstreamOIDCProvider %s", created.Name)
|
||||
|
||||
// Wait for the UpstreamOIDCProvider to enter the expected phase (or time out).
|
||||
var result *v1alpha1.UpstreamOIDCProvider
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var err error
|
||||
result, err = client.IDPV1alpha1().
|
||||
UpstreamOIDCProviders(created.Namespace).Get(ctx, created.Name, metav1.GetOptions{})
|
||||
require.NoError(t, err)
|
||||
return result.Status.Phase == expectedPhase
|
||||
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase)
|
||||
return result
|
||||
}
|
||||
|
158
test/library/browsertest/browsertest.go
Normal file
158
test/library/browsertest/browsertest.go
Normal file
@ -0,0 +1,158 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package browsertest provides integration test helpers for our browser-based tests.
|
||||
package browsertest
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sclevine/agouti"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pinniped.dev/test/library"
|
||||
)
|
||||
|
||||
const (
|
||||
operationTimeout = 10 * time.Second
|
||||
operationPollingInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
// Open a webdriver-driven browser and returns an *agouti.Page to control it. The browser will be automatically
|
||||
// closed at the end of the current test. It is configured for test purposes with the correct HTTP proxy and
|
||||
// in a mode that ignore certificate errors.
|
||||
func Open(t *testing.T) *agouti.Page {
|
||||
t.Logf("opening browser driver")
|
||||
env := library.IntegrationEnv(t)
|
||||
caps := agouti.NewCapabilities()
|
||||
if env.Proxy != "" {
|
||||
t.Logf("configuring Chrome to use proxy %q", env.Proxy)
|
||||
caps = caps.Proxy(agouti.ProxyConfig{
|
||||
ProxyType: "manual",
|
||||
HTTPProxy: env.Proxy,
|
||||
SSLProxy: env.Proxy,
|
||||
NoProxy: "127.0.0.1",
|
||||
})
|
||||
}
|
||||
agoutiDriver := agouti.ChromeDriver(
|
||||
agouti.Desired(caps),
|
||||
agouti.ChromeOptions("args", []string{
|
||||
"--no-sandbox",
|
||||
"--ignore-certificate-errors",
|
||||
"--headless", // Comment out this line to see the tests happen in a visible browser window.
|
||||
}),
|
||||
// Uncomment this to see stdout/stderr from chromedriver.
|
||||
// agouti.Debug,
|
||||
)
|
||||
require.NoError(t, agoutiDriver.Start())
|
||||
t.Cleanup(func() { require.NoError(t, agoutiDriver.Stop()) })
|
||||
page, err := agoutiDriver.NewPage(agouti.Browser("chrome"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, page.Reset())
|
||||
return page
|
||||
}
|
||||
|
||||
// WaitForVisibleElements expects the page to contain all the the elements specified by the selectors. It waits for this
|
||||
// to occur and times out, failing the test, if they never appear.
|
||||
func WaitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {
|
||||
t.Helper()
|
||||
|
||||
require.Eventuallyf(t,
|
||||
func() bool {
|
||||
for _, sel := range selectors {
|
||||
vis, err := page.First(sel).Visible()
|
||||
if !(err == nil && vis) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
operationTimeout,
|
||||
operationPollingInterval,
|
||||
"expected to have a page with selectors %v, but it never loaded",
|
||||
selectors,
|
||||
)
|
||||
}
|
||||
|
||||
// WaitForURL expects the page to eventually navigate to a URL matching the specified pattern. It waits for this
|
||||
// to occur and times out, failing the test, if it never does.
|
||||
func WaitForURL(t *testing.T, page *agouti.Page, pat *regexp.Regexp) {
|
||||
var lastURL string
|
||||
require.Eventuallyf(t,
|
||||
func() bool {
|
||||
url, err := page.URL()
|
||||
if err == nil && pat.MatchString(url) {
|
||||
return true
|
||||
}
|
||||
if url != lastURL {
|
||||
t.Logf("saw URL %s", url)
|
||||
lastURL = url
|
||||
}
|
||||
return false
|
||||
},
|
||||
operationTimeout,
|
||||
operationPollingInterval,
|
||||
"expected to browse to %s, but never got there",
|
||||
pat,
|
||||
)
|
||||
}
|
||||
|
||||
// LoginToUpstream expects the page to be redirected to one of several known upstream IDPs.
|
||||
// It knows how to enter the test username/password and submit the upstream login form.
|
||||
func LoginToUpstream(t *testing.T, page *agouti.Page, upstream library.TestOIDCUpstream) {
|
||||
t.Helper()
|
||||
|
||||
type config struct {
|
||||
Name string
|
||||
IssuerPattern *regexp.Regexp
|
||||
LoginPagePattern *regexp.Regexp
|
||||
UsernameSelector string
|
||||
PasswordSelector string
|
||||
LoginButtonSelector string
|
||||
}
|
||||
|
||||
// Lookup the provider by matching on the issuer URL.
|
||||
var cfg *config
|
||||
for _, p := range []*config{
|
||||
{
|
||||
Name: "Okta",
|
||||
IssuerPattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
|
||||
LoginPagePattern: regexp.MustCompile(`\Ahttps://.+\.okta\.com/.+\z`),
|
||||
UsernameSelector: "input#okta-signin-username",
|
||||
PasswordSelector: "input#okta-signin-password",
|
||||
LoginButtonSelector: "input#okta-signin-submit",
|
||||
},
|
||||
{
|
||||
Name: "Dex",
|
||||
IssuerPattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex.*\z`),
|
||||
LoginPagePattern: regexp.MustCompile(`\Ahttps://dex\.dex\.svc\.cluster\.local/dex/auth/local.+\z`),
|
||||
UsernameSelector: "input#login",
|
||||
PasswordSelector: "input#password",
|
||||
LoginButtonSelector: "button#submit-login",
|
||||
},
|
||||
} {
|
||||
if p.IssuerPattern.MatchString(upstream.Issuer) {
|
||||
cfg = p
|
||||
break
|
||||
}
|
||||
}
|
||||
if cfg == nil {
|
||||
require.Failf(t, "could not find login provider for issuer %q", upstream.Issuer)
|
||||
return
|
||||
}
|
||||
|
||||
// Expect to be redirected to the login page.
|
||||
t.Logf("waiting for redirect to %s login page", cfg.Name)
|
||||
WaitForURL(t, page, cfg.LoginPagePattern)
|
||||
|
||||
// Wait for the login page to be rendered.
|
||||
WaitForVisibleElements(t, page, cfg.UsernameSelector, cfg.PasswordSelector, cfg.LoginButtonSelector)
|
||||
|
||||
// Fill in the username and password and click "submit".
|
||||
t.Logf("logging into %s", cfg.Name)
|
||||
require.NoError(t, page.First(cfg.UsernameSelector).Fill(upstream.Username))
|
||||
require.NoError(t, page.First(cfg.PasswordSelector).Fill(upstream.Password))
|
||||
require.NoError(t, page.First(cfg.LoginButtonSelector).Click())
|
||||
}
|
@ -25,6 +25,7 @@ import (
|
||||
|
||||
auth1alpha1 "go.pinniped.dev/generated/1.19/apis/concierge/authentication/v1alpha1"
|
||||
configv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/config/v1alpha1"
|
||||
idpv1alpha1 "go.pinniped.dev/generated/1.19/apis/supervisor/idp/v1alpha1"
|
||||
conciergeclientset "go.pinniped.dev/generated/1.19/client/concierge/clientset/versioned"
|
||||
supervisorclientset "go.pinniped.dev/generated/1.19/client/supervisor/clientset/versioned"
|
||||
|
||||
@ -140,11 +141,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty
|
||||
defer cancel()
|
||||
|
||||
webhook, err := webhooks.Create(createContext, &auth1alpha1.WebhookAuthenticator{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
GenerateName: "test-webhook-",
|
||||
Labels: map[string]string{"pinniped.dev/test": ""},
|
||||
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
|
||||
},
|
||||
ObjectMeta: testObjectMeta(t, "webhook"),
|
||||
Spec: testEnv.TestWebhook,
|
||||
}, metav1.CreateOptions{})
|
||||
require.NoError(t, err, "could not create test WebhookAuthenticator")
|
||||
@ -172,7 +169,7 @@ func CreateTestWebhookAuthenticator(ctx context.Context, t *testing.T) corev1.Ty
|
||||
//
|
||||
// If the provided issuer is not the empty string, then it will be used for the
|
||||
// OIDCProvider.Spec.Issuer field. Else, a random issuer will be generated.
|
||||
func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecretName string) *configv1alpha1.OIDCProvider {
|
||||
func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer string, certSecretName string, expectStatus configv1alpha1.OIDCProviderStatusCondition) *configv1alpha1.OIDCProvider {
|
||||
t.Helper()
|
||||
testEnv := IntegrationEnv(t)
|
||||
|
||||
@ -180,18 +177,12 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre
|
||||
defer cancel()
|
||||
|
||||
if issuer == "" {
|
||||
var err error
|
||||
issuer, err = randomIssuer()
|
||||
require.NoError(t, err)
|
||||
issuer = randomIssuer(t)
|
||||
}
|
||||
|
||||
opcs := NewSupervisorClientset(t).ConfigV1alpha1().OIDCProviders(testEnv.SupervisorNamespace)
|
||||
opc, err := opcs.Create(createContext, &configv1alpha1.OIDCProvider{
|
||||
ObjectMeta: metav1.ObjectMeta{
|
||||
GenerateName: "test-oidc-provider-",
|
||||
Labels: map[string]string{"pinniped.dev/test": ""},
|
||||
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
|
||||
},
|
||||
ObjectMeta: testObjectMeta(t, "oidc-provider"),
|
||||
Spec: configv1alpha1.OIDCProviderSpec{
|
||||
Issuer: issuer,
|
||||
TLS: &configv1alpha1.OIDCProviderTLSSpec{SecretName: certSecretName},
|
||||
@ -213,13 +204,103 @@ func CreateTestOIDCProvider(ctx context.Context, t *testing.T, issuer, certSecre
|
||||
}
|
||||
})
|
||||
|
||||
// If we're not expecting any particular status, just return the new OIDCProvider immediately.
|
||||
if expectStatus == "" {
|
||||
return opc
|
||||
}
|
||||
|
||||
func randomIssuer() (string, error) {
|
||||
// Wait for the OIDCProvider to enter the expected phase (or time out).
|
||||
var result *configv1alpha1.OIDCProvider
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var err error
|
||||
result, err = opcs.Get(ctx, opc.Name, metav1.GetOptions{})
|
||||
require.NoError(t, err)
|
||||
return result.Status.Status == expectStatus
|
||||
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectStatus)
|
||||
|
||||
return opc
|
||||
}
|
||||
|
||||
func randomIssuer(t *testing.T) string {
|
||||
var buf [8]byte
|
||||
if _, err := io.ReadFull(rand.Reader, buf[:]); err != nil {
|
||||
return "", fmt.Errorf("could not generate random state: %w", err)
|
||||
_, err := io.ReadFull(rand.Reader, buf[:])
|
||||
require.NoError(t, err)
|
||||
return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:]))
|
||||
}
|
||||
|
||||
func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType string, stringData map[string]string) *corev1.Secret {
|
||||
t.Helper()
|
||||
client := NewClientset(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
created, err := client.CoreV1().Secrets(namespace).Create(ctx, &corev1.Secret{
|
||||
ObjectMeta: testObjectMeta(t, baseName),
|
||||
Type: corev1.SecretType(secretType),
|
||||
StringData: stringData,
|
||||
}, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
err := client.CoreV1().Secrets(namespace).Delete(context.Background(), created.Name, metav1.DeleteOptions{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Logf("created test Secret %s", created.Name)
|
||||
return created
|
||||
}
|
||||
|
||||
func CreateClientCredsSecret(t *testing.T, clientID string, clientSecret string) *corev1.Secret {
|
||||
t.Helper()
|
||||
env := IntegrationEnv(t)
|
||||
return CreateTestSecret(t,
|
||||
env.SupervisorNamespace,
|
||||
"test-client-creds-",
|
||||
"secrets.pinniped.dev/oidc-client",
|
||||
map[string]string{
|
||||
"clientID": clientID,
|
||||
"clientSecret": clientSecret,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func CreateTestUpstreamOIDCProvider(t *testing.T, spec idpv1alpha1.UpstreamOIDCProviderSpec, expectedPhase idpv1alpha1.UpstreamOIDCProviderPhase) *idpv1alpha1.UpstreamOIDCProvider {
|
||||
t.Helper()
|
||||
env := IntegrationEnv(t)
|
||||
client := NewSupervisorClientset(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Create the UpstreamOIDCProvider using GenerateName to get a random name.
|
||||
upstreams := client.IDPV1alpha1().UpstreamOIDCProviders(env.SupervisorNamespace)
|
||||
|
||||
created, err := upstreams.Create(ctx, &idpv1alpha1.UpstreamOIDCProvider{
|
||||
ObjectMeta: testObjectMeta(t, "upstream"),
|
||||
Spec: spec,
|
||||
}, metav1.CreateOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Always clean this up after this point.
|
||||
t.Cleanup(func() {
|
||||
err := upstreams.Delete(context.Background(), created.Name, metav1.DeleteOptions{})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Logf("created test UpstreamOIDCProvider %s", created.Name)
|
||||
|
||||
// Wait for the UpstreamOIDCProvider to enter the expected phase (or time out).
|
||||
var result *idpv1alpha1.UpstreamOIDCProvider
|
||||
require.Eventuallyf(t, func() bool {
|
||||
var err error
|
||||
result, err = upstreams.Get(ctx, created.Name, metav1.GetOptions{})
|
||||
require.NoError(t, err)
|
||||
return result.Status.Phase == expectedPhase
|
||||
}, 60*time.Second, 1*time.Second, "expected the UpstreamOIDCProvider to go into phase %s", expectedPhase)
|
||||
return result
|
||||
}
|
||||
|
||||
func testObjectMeta(t *testing.T, baseName string) metav1.ObjectMeta {
|
||||
return metav1.ObjectMeta{
|
||||
GenerateName: fmt.Sprintf("test-%s-", baseName),
|
||||
Labels: map[string]string{"pinniped.dev/test": ""},
|
||||
Annotations: map[string]string{"pinniped.dev/testName": t.Name()},
|
||||
}
|
||||
return fmt.Sprintf("http://test-issuer-%s.pinniped.dev", hex.EncodeToString(buf[:])), nil
|
||||
}
|
||||
|
49
test/library/dumplogs.go
Normal file
49
test/library/dumplogs.go
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
)
|
||||
|
||||
// DumpLogs is meant to be called in a `defer` to dump the logs of components in the cluster on a test failure.
|
||||
func DumpLogs(t *testing.T, namespace string) {
|
||||
// Only trigger on failed tests.
|
||||
if !t.Failed() {
|
||||
return
|
||||
}
|
||||
|
||||
kubeClient := NewClientset(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
logTailLines := int64(40)
|
||||
pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, pod := range pods.Items {
|
||||
for _, container := range pod.Status.ContainerStatuses {
|
||||
t.Logf("pod %s/%s container %s restarted %d times:", pod.Namespace, pod.Name, container.Name, container.RestartCount)
|
||||
req := kubeClient.CoreV1().Pods(namespace).GetLogs(pod.Name, &corev1.PodLogOptions{
|
||||
Container: container.Name,
|
||||
TailLines: &logTailLines,
|
||||
})
|
||||
logReader, err := req.Stream(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
scanner := bufio.NewScanner(logReader)
|
||||
for scanner.Scan() {
|
||||
t.Logf("%s/%s/%s > %s", pod.Namespace, pod.Name, container.Name, scanner.Text())
|
||||
}
|
||||
require.NoError(t, scanner.Err())
|
||||
}
|
||||
}
|
||||
}
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -26,18 +27,22 @@ func (l *testlogReader) Read(p []byte) (n int, err error) {
|
||||
l.t.Helper()
|
||||
n, err = l.r.Read(p)
|
||||
if err != nil {
|
||||
l.t.Logf("%s > %q: %v", l.name, maskTokens(p[0:n]), err)
|
||||
l.t.Logf("%s > %q: %v", l.name, MaskTokens(string(p[0:n])), err)
|
||||
} else {
|
||||
l.t.Logf("%s > %q", l.name, maskTokens(p[0:n]))
|
||||
l.t.Logf("%s > %q", l.name, MaskTokens(string(p[0:n])))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//nolint: gochecknoglobals
|
||||
// MaskTokens makes a best-effort attempt to mask out things that look like secret tokens in test output.
|
||||
// The goal is more to have readable test output than for any security reason.
|
||||
func MaskTokens(in string) string {
|
||||
var tokenLike = regexp.MustCompile(`(?mi)[a-zA-Z0-9._-]{30,}|[a-zA-Z0-9]{20,}`)
|
||||
|
||||
func maskTokens(in []byte) string {
|
||||
return tokenLike.ReplaceAllStringFunc(string(in), func(t string) string {
|
||||
return tokenLike.ReplaceAllStringFunc(in, func(t string) string {
|
||||
// This is a silly heuristic, but things with multiple dots are more likely hostnames that we don't want masked.
|
||||
if strings.Count(t, ".") >= 4 {
|
||||
return t
|
||||
}
|
||||
return fmt.Sprintf("[...%d bytes...]", len(t))
|
||||
})
|
||||
}
|
||||
|
16
test/library/iplookup.go
Normal file
16
test/library/iplookup.go
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// +build !go1.14
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// LookupIP looks up the IP address of the provided hostname, preferring IPv4.
|
||||
func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) {
|
||||
return net.DefaultResolver.LookupIP(ctx, "ip4", hostname)
|
||||
}
|
28
test/library/iplookup_go1.14.go
Normal file
28
test/library/iplookup_go1.14.go
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// +build go1.14
|
||||
|
||||
package library
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
// LookupIP looks up the IP address of the provided hostname, preferring IPv4.
|
||||
func LookupIP(ctx context.Context, hostname string) ([]net.IP, error) {
|
||||
ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Filter out to only IPv4 addresses
|
||||
var results []net.IP
|
||||
for _, ip := range ips {
|
||||
if ip.IP.To4() != nil {
|
||||
results = append(results, ip.IP)
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user