Merge pull request #687 from mattmoyer/add-response-mode-form-post

Add support for "response_mode=form_post" in Supervisor and CLI.
This commit is contained in:
Matt Moyer 2021-07-09 10:37:59 -07:00 committed by GitHub
commit 405a27ba90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1498 additions and 178 deletions

View File

@ -61,6 +61,7 @@ type getKubeconfigOIDCParams struct {
listenPort uint16 listenPort uint16
scopes []string scopes []string
skipBrowser bool skipBrowser bool
skipListen bool
sessionCachePath string sessionCachePath string
debugSessionCache bool debugSessionCache bool
caBundle caBundleFlag caBundle caBundleFlag
@ -146,6 +147,7 @@ func kubeconfigCommand(deps kubeconfigDeps) *cobra.Command {
f.Uint16Var(&flags.oidc.listenPort, "oidc-listen-port", 0, "TCP port for localhost listener (authorization code flow only)") f.Uint16Var(&flags.oidc.listenPort, "oidc-listen-port", 0, "TCP port for localhost listener (authorization code flow only)")
f.StringSliceVar(&flags.oidc.scopes, "oidc-scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OpenID Connect scopes to request during login") f.StringSliceVar(&flags.oidc.scopes, "oidc-scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OpenID Connect scopes to request during login")
f.BoolVar(&flags.oidc.skipBrowser, "oidc-skip-browser", false, "During OpenID Connect login, skip opening the browser (just print the URL)") f.BoolVar(&flags.oidc.skipBrowser, "oidc-skip-browser", false, "During OpenID Connect login, skip opening the browser (just print the URL)")
f.BoolVar(&flags.oidc.skipListen, "oidc-skip-listen", false, "During OpenID Connect login, skip starting a localhost callback listener (manual copy/paste flow only)")
f.StringVar(&flags.oidc.sessionCachePath, "oidc-session-cache", "", "Path to OpenID Connect session cache file") f.StringVar(&flags.oidc.sessionCachePath, "oidc-session-cache", "", "Path to OpenID Connect session cache file")
f.Var(&flags.oidc.caBundle, "oidc-ca-bundle", "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)") f.Var(&flags.oidc.caBundle, "oidc-ca-bundle", "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)")
f.BoolVar(&flags.oidc.debugSessionCache, "oidc-debug-session-cache", false, "Print debug logs related to the OpenID Connect session cache") f.BoolVar(&flags.oidc.debugSessionCache, "oidc-debug-session-cache", false, "Print debug logs related to the OpenID Connect session cache")
@ -161,6 +163,9 @@ func kubeconfigCommand(deps kubeconfigDeps) *cobra.Command {
f.StringVar(&flags.credentialCachePath, "credential-cache", "", "Path to cluster-specific credentials cache") f.StringVar(&flags.credentialCachePath, "credential-cache", "", "Path to cluster-specific credentials cache")
mustMarkHidden(cmd, "oidc-debug-session-cache") mustMarkHidden(cmd, "oidc-debug-session-cache")
// --oidc-skip-listen is mainly needed for testing. We'll leave it hidden until we have a non-testing use case.
mustMarkHidden(cmd, "oidc-skip-listen")
mustMarkDeprecated(cmd, "concierge-namespace", "not needed anymore") mustMarkDeprecated(cmd, "concierge-namespace", "not needed anymore")
mustMarkHidden(cmd, "concierge-namespace") mustMarkHidden(cmd, "concierge-namespace")
@ -317,6 +322,9 @@ func newExecConfig(deps kubeconfigDeps, flags getKubeconfigParams) (*clientcmdap
if flags.oidc.skipBrowser { if flags.oidc.skipBrowser {
execConfig.Args = append(execConfig.Args, "--skip-browser") execConfig.Args = append(execConfig.Args, "--skip-browser")
} }
if flags.oidc.skipListen {
execConfig.Args = append(execConfig.Args, "--skip-listen")
}
if flags.oidc.listenPort != 0 { if flags.oidc.listenPort != 0 {
execConfig.Args = append(execConfig.Args, "--listen-port="+strconv.Itoa(int(flags.oidc.listenPort))) execConfig.Args = append(execConfig.Args, "--listen-port="+strconv.Itoa(int(flags.oidc.listenPort)))
} }

View File

@ -1352,6 +1352,7 @@ func TestGetKubeconfig(t *testing.T) {
"--concierge-ca-bundle", testConciergeCABundlePath, "--concierge-ca-bundle", testConciergeCABundlePath,
"--oidc-issuer", issuerURL, "--oidc-issuer", issuerURL,
"--oidc-skip-browser", "--oidc-skip-browser",
"--oidc-skip-listen",
"--oidc-listen-port", "1234", "--oidc-listen-port", "1234",
"--oidc-ca-bundle", f.Name(), "--oidc-ca-bundle", f.Name(),
"--oidc-session-cache", "/path/to/cache/dir/sessions.yaml", "--oidc-session-cache", "/path/to/cache/dir/sessions.yaml",
@ -1405,6 +1406,7 @@ func TestGetKubeconfig(t *testing.T) {
- --client-id=pinniped-cli - --client-id=pinniped-cli
- --scopes=offline_access,openid,pinniped:request-audience - --scopes=offline_access,openid,pinniped:request-audience
- --skip-browser - --skip-browser
- --skip-listen
- --listen-port=1234 - --listen-port=1234
- --ca-bundle-data=%s - --ca-bundle-data=%s
- --session-cache=/path/to/cache/dir/sessions.yaml - --session-cache=/path/to/cache/dir/sessions.yaml

View File

@ -59,6 +59,7 @@ type oidcLoginFlags struct {
listenPort uint16 listenPort uint16
scopes []string scopes []string
skipBrowser bool skipBrowser bool
skipListen bool
sessionCachePath string sessionCachePath string
caBundlePaths []string caBundlePaths []string
caBundleData []string caBundleData []string
@ -91,6 +92,7 @@ func oidcLoginCommand(deps oidcLoginCommandDeps) *cobra.Command {
cmd.Flags().Uint16Var(&flags.listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only)") cmd.Flags().Uint16Var(&flags.listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only)")
cmd.Flags().StringSliceVar(&flags.scopes, "scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OIDC scopes to request during login") cmd.Flags().StringSliceVar(&flags.scopes, "scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OIDC scopes to request during login")
cmd.Flags().BoolVar(&flags.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL)") cmd.Flags().BoolVar(&flags.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL)")
cmd.Flags().BoolVar(&flags.skipListen, "skip-listen", false, "Skip starting a localhost callback listener (manual copy/paste flow only)")
cmd.Flags().StringVar(&flags.sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file") cmd.Flags().StringVar(&flags.sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file")
cmd.Flags().StringSliceVar(&flags.caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)") cmd.Flags().StringSliceVar(&flags.caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)")
cmd.Flags().StringSliceVar(&flags.caBundleData, "ca-bundle-data", nil, "Base64 encoded TLS certificate authority bundle (base64 encoded PEM format, optional, can be repeated)") cmd.Flags().StringSliceVar(&flags.caBundleData, "ca-bundle-data", nil, "Base64 encoded TLS certificate authority bundle (base64 encoded PEM format, optional, can be repeated)")
@ -107,6 +109,8 @@ func oidcLoginCommand(deps oidcLoginCommandDeps) *cobra.Command {
cmd.Flags().StringVar(&flags.upstreamIdentityProviderName, "upstream-identity-provider-name", "", "The name of the upstream identity provider used during login with a Supervisor") cmd.Flags().StringVar(&flags.upstreamIdentityProviderName, "upstream-identity-provider-name", "", "The name of the upstream identity provider used during login with a Supervisor")
cmd.Flags().StringVar(&flags.upstreamIdentityProviderType, "upstream-identity-provider-type", "oidc", "The type of the upstream identity provider used during login with a Supervisor (e.g. 'oidc', 'ldap')") cmd.Flags().StringVar(&flags.upstreamIdentityProviderType, "upstream-identity-provider-type", "oidc", "The type of the upstream identity provider used during login with a Supervisor (e.g. 'oidc', 'ldap')")
// --skip-listen is mainly needed for testing. We'll leave it hidden until we have a non-testing use case.
mustMarkHidden(cmd, "skip-listen")
mustMarkHidden(cmd, "debug-session-cache") mustMarkHidden(cmd, "debug-session-cache")
mustMarkRequired(cmd, "issuer") mustMarkRequired(cmd, "issuer")
cmd.RunE = func(cmd *cobra.Command, args []string) error { return runOIDCLogin(cmd, deps, flags) } cmd.RunE = func(cmd *cobra.Command, args []string) error { return runOIDCLogin(cmd, deps, flags) }
@ -182,12 +186,14 @@ func runOIDCLogin(cmd *cobra.Command, deps oidcLoginCommandDeps, flags oidcLogin
} }
} }
// --skip-browser replaces the default "browser open" function with one that prints to stderr. // --skip-browser skips opening the browser.
if flags.skipBrowser { if flags.skipBrowser {
opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error { opts = append(opts, oidcclient.WithSkipBrowserOpen())
cmd.PrintErr("Please log in: ", url, "\n") }
return nil
})) // --skip-listen skips starting the localhost callback listener.
if flags.skipListen {
opts = append(opts, oidcclient.WithSkipListen())
} }
if len(flags.caBundlePaths) > 0 || len(flags.caBundleData) > 0 { if len(flags.caBundlePaths) > 0 || len(flags.caBundleData) > 0 {

View File

@ -226,6 +226,7 @@ func TestLoginOIDCCommand(t *testing.T) {
"--client-id", "test-client-id", "--client-id", "test-client-id",
"--issuer", "test-issuer", "--issuer", "test-issuer",
"--skip-browser", "--skip-browser",
"--skip-listen",
"--listen-port", "1234", "--listen-port", "1234",
"--debug-session-cache", "--debug-session-cache",
"--request-audience", "cluster-1234", "--request-audience", "cluster-1234",
@ -242,7 +243,7 @@ func TestLoginOIDCCommand(t *testing.T) {
"--upstream-identity-provider-type", "ldap", "--upstream-identity-provider-type", "ldap",
}, },
env: map[string]string{"PINNIPED_DEBUG": "true"}, env: map[string]string{"PINNIPED_DEBUG": "true"},
wantOptionsCount: 10, wantOptionsCount: 11,
wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"token":"exchanged-token"}}` + "\n", wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"token":"exchanged-token"}}` + "\n",
wantLogs: []string{ wantLogs: []string{
"\"level\"=0 \"msg\"=\"Pinniped login: Performing OIDC login\" \"client id\"=\"test-client-id\" \"issuer\"=\"test-issuer\"", "\"level\"=0 \"msg\"=\"Pinniped login: Performing OIDC login\" \"client id\"=\"test-client-id\" \"issuer\"=\"test-issuer\"",

3
go.mod
View File

@ -1,6 +1,6 @@
module go.pinniped.dev module go.pinniped.dev
go 1.14 go 1.16
require ( require (
github.com/MakeNowJust/heredoc/v2 v2.0.1 github.com/MakeNowJust/heredoc/v2 v2.0.1
@ -26,6 +26,7 @@ require (
github.com/spf13/cobra v1.2.1 github.com/spf13/cobra v1.2.1
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.7.0 github.com/stretchr/testify v1.7.0
github.com/tdewolff/minify/v2 v2.9.18
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023 golang.org/x/net v0.0.0-20210520170846-37e1c6afe023
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602

8
go.sum
View File

@ -118,6 +118,7 @@ github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927/go.mod h1:h/aW8ynjgkuj+NQRlZcDbAbM1ORAbXjXX77sX7T289U=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
@ -840,6 +841,7 @@ github.com/markbates/safe v1.0.0/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kN
github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0=
github.com/markbates/sigtx v1.0.0/go.mod h1:QF1Hv6Ic6Ca6W+T+DL0Y/ypborFKyvUY9HmuCD4VeTc= github.com/markbates/sigtx v1.0.0/go.mod h1:QF1Hv6Ic6Ca6W+T+DL0Y/ypborFKyvUY9HmuCD4VeTc=
github.com/markbates/willie v1.0.9/go.mod h1:fsrFVWl91+gXpx/6dv715j7i11fYPfZ9ZGfH0DQzY7w= github.com/markbates/willie v1.0.9/go.mod h1:fsrFVWl91+gXpx/6dv715j7i11fYPfZ9ZGfH0DQzY7w=
github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2/go.mod h1:0KeJpeMD6o+O4hW7qJOT7vyQPKrWmj26uf5wMc/IiIs=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
@ -1150,6 +1152,12 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/subosito/gotenv v1.1.1/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.1.1/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tdewolff/minify/v2 v2.9.18 h1:j5Is0sOGp4cxm0o3HgvHCWCvTtmKnfB0qv0FCRbmgZY=
github.com/tdewolff/minify/v2 v2.9.18/go.mod h1:0y0mXZnisZm8HcgQvAV0btxa1IgecGam90zMuHqEZuc=
github.com/tdewolff/parse/v2 v2.5.18 h1:d67Ql/Pe36JcJZ7J2MY8upx6iTxbxGS9lzwyFGtMmd0=
github.com/tdewolff/parse/v2 v2.5.18/go.mod h1:WzaJpRSbwq++EIQHYIRTpbYKNA3gn9it1Ik++q4zyho=
github.com/tdewolff/test v1.0.6 h1:76mzYJQ83Op284kMT+63iCNCI7NEERsIN8dLM+RiKr4=
github.com/tdewolff/test v1.0.6/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE=
github.com/tidwall/gjson v1.3.2/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.3.2/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/gjson v1.6.8/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI= github.com/tidwall/gjson v1.6.8/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI=
github.com/tidwall/gjson v1.7.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= github.com/tidwall/gjson v1.7.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk=

View File

@ -1,16 +1,22 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package securityheader implements an HTTP middleware for setting security-related response headers. // Package securityheader implements an HTTP middleware for setting security-related response headers.
package securityheader package securityheader
import "net/http" import (
"net/http"
)
// Wrap the provided http.Handler so it sets appropriate security-related response headers. // Wrap the provided http.Handler so it sets appropriate security-related response headers.
func Wrap(wrapped http.Handler) http.Handler { func Wrap(wrapped http.Handler) http.Handler {
return WrapWithCustomCSP(wrapped, "default-src 'none'; frame-ancestors 'none'")
}
func WrapWithCustomCSP(wrapped http.Handler, cspHeader string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header() h := w.Header()
h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") h.Set("Content-Security-Policy", cspHeader)
h.Set("X-Frame-Options", "DENY") h.Set("X-Frame-Options", "DENY")
h.Set("X-XSS-Protection", "1; mode=block") h.Set("X-XSS-Protection", "1; mode=block")
h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Content-Type-Options", "nosniff")

View File

@ -16,7 +16,49 @@ import (
) )
func TestWrap(t *testing.T) { func TestWrap(t *testing.T) {
testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for _, tt := range []struct {
name string
wrapFunc func(http.Handler) http.Handler
expectHeaders http.Header
}{
{
name: "wrap",
wrapFunc: Wrap,
expectHeaders: http.Header{
"X-Test-Header": []string{"test value"},
"Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"},
"X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"},
"X-Dns-Prefetch-Control": []string{"off"},
"Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"},
"Pragma": []string{"no-cache"},
"Expires": []string{"0"},
},
},
{
name: "custom CSP",
wrapFunc: func(h http.Handler) http.Handler { return WrapWithCustomCSP(h, "my-custom-csp-header") },
expectHeaders: http.Header{
"X-Test-Header": []string{"test value"},
"Content-Security-Policy": []string{"my-custom-csp-header"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"},
"X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"},
"X-Dns-Prefetch-Control": []string{"off"},
"Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"},
"Pragma": []string{"no-cache"},
"Expires": []string{"0"},
},
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
testServer := httptest.NewServer(tt.wrapFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test-Header", "test value") w.Header().Set("X-Test-Header", "test value")
_, _ = w.Write([]byte("hello world")) _, _ = w.Write([]byte("hello world"))
}))) })))
@ -36,20 +78,9 @@ func TestWrap(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "hello world", string(respBody)) require.Equal(t, "hello world", string(respBody))
expected := http.Header{ for key, values := range tt.expectHeaders {
"X-Test-Header": []string{"test value"},
"Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"},
"X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"},
"X-Dns-Prefetch-Control": []string{"off"},
"Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"},
"Pragma": []string{"no-cache"},
"Expires": []string{"0"},
}
for key, values := range expected {
assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key) assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key)
} }
})
}
} }

View File

@ -1156,7 +1156,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords)
case test.wantRedirectLocationRegexp != "": case test.wantRedirectLocationRegexp != "":
require.Len(t, rsp.Header().Values("Location"), 1) require.Len(t, rsp.Header().Values("Location"), 1)
oidctestutil.RequireAuthcodeRedirectLocation( oidctestutil.RequireAuthCodeRegexpMatch(
t, t,
rsp.Header().Get("Location"), rsp.Header().Get("Location"),
test.wantRedirectLocationRegexp, test.wantRedirectLocationRegexp,

View File

@ -18,6 +18,7 @@ import (
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -35,7 +36,7 @@ func NewHandler(
stateDecoder, cookieDecoder oidc.Decoder, stateDecoder, cookieDecoder oidc.Decoder,
redirectURI string, redirectURI string,
) http.Handler { ) http.Handler {
return securityheader.Wrap(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { handler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder) state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil { if err != nil {
return err return err
@ -97,7 +98,8 @@ func NewHandler(
oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)
return nil return nil
})) })
return securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy())
} }
func authcode(r *http.Request) string { func authcode(r *http.Request) string {

View File

@ -122,6 +122,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantContentType string wantContentType string
wantBody string wantBody string
wantRedirectLocationRegexp string wantRedirectLocationRegexp string
wantBodyFormResponseRegexp string
wantDownstreamGrantedScopes []string wantDownstreamGrantedScopes []string
wantDownstreamIDTokenSubject string wantDownstreamIDTokenSubject string
wantDownstreamIDTokenUsername string wantDownstreamIDTokenUsername string
@ -133,6 +134,32 @@ func TestCallbackEndpoint(t *testing.T) {
wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs
}{ }{
{
name: "GET with good state and cookie and successful upstream token exchange with response_mode=form_post returns 200 with HTML+JS form",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().WithAuthorizeRequestParams(
shallowCopyAndModifyQuery(
happyDownstreamRequestParamsQuery,
map[string]string{"response_mode": "form_post"},
).Encode(),
).Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: "text/html;charset=UTF-8",
wantBodyFormResponseRegexp: `<code id="manual-auth-code">(.+)</code>`,
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject,
wantDownstreamIDTokenUsername: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantDownstreamGrantedScopes: happyDownstreamScopesGranted,
wantDownstreamNonce: downstreamNonce,
wantDownstreamPKCEChallenge: downstreamPKCEChallenge,
wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{ {
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idp: happyUpstream().Build(), idp: happyUpstream().Build(),
@ -666,15 +693,40 @@ func TestCallbackEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
if test.wantBody != "" { switch {
// If we want a specific static response body, assert that.
case test.wantBody != "":
require.Equal(t, test.wantBody, rsp.Body.String()) require.Equal(t, test.wantBody, rsp.Body.String())
} else {
// Else if we want a body that contains a regex-matched auth code, assert that (for "response_mode=form_post").
case test.wantBodyFormResponseRegexp != "":
oidctestutil.RequireAuthCodeRegexpMatch(
t,
rsp.Body.String(),
test.wantBodyFormResponseRegexp,
client,
secrets,
oauthStore,
test.wantDownstreamGrantedScopes,
test.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenUsername,
test.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes,
test.wantDownstreamPKCEChallenge,
test.wantDownstreamPKCEChallengeMethod,
test.wantDownstreamNonce,
downstreamClientID,
downstreamRedirectURI,
)
// Otherwise, expect an empty response body.
default:
require.Empty(t, rsp.Body.String()) require.Empty(t, rsp.Body.String())
} }
if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test
require.Len(t, rsp.Header().Values("Location"), 1) require.Len(t, rsp.Header().Values("Location"), 1)
oidctestutil.RequireAuthcodeRedirectLocation( oidctestutil.RequireAuthCodeRegexpMatch(
t, t,
rsp.Header().Get("Location"), rsp.Header().Get("Location"),
test.wantRedirectLocationRegexp, test.wantRedirectLocationRegexp,

View File

@ -18,10 +18,16 @@ type Client struct {
fosite.DefaultOpenIDConnectClient fosite.DefaultOpenIDConnectClient
} }
// It implements both the base and OIDC client interfaces of Fosite. func (c Client) GetResponseModes() []fosite.ResponseModeType {
// For now, all Pinniped clients always support "" (unspecified), "query", and "form_post" response modes.
return []fosite.ResponseModeType{fosite.ResponseModeDefault, fosite.ResponseModeQuery, fosite.ResponseModeFormPost}
}
// It implements both the base, OIDC, and response_mode client interfaces of Fosite.
var ( var (
_ fosite.Client = (*Client)(nil) _ fosite.Client = (*Client)(nil)
_ fosite.OpenIDConnectClient = (*Client)(nil) _ fosite.OpenIDConnectClient = (*Client)(nil)
_ fosite.ResponseModeClient = (*Client)(nil)
) )
// StaticClientManager is a fosite.ClientManager with statically-defined clients. // StaticClientManager is a fosite.ClientManager with statically-defined clients.

View File

@ -59,6 +59,7 @@ func TestPinnipedCLI(t *testing.T) {
require.Equal(t, "", c.GetRequestObjectSigningAlgorithm()) require.Equal(t, "", c.GetRequestObjectSigningAlgorithm())
require.Equal(t, "none", c.GetTokenEndpointAuthMethod()) require.Equal(t, "none", c.GetTokenEndpointAuthMethod())
require.Equal(t, "RS256", c.GetTokenEndpointAuthSigningAlgorithm()) require.Equal(t, "RS256", c.GetTokenEndpointAuthSigningAlgorithm())
require.Equal(t, []fosite.ResponseModeType{"", "query", "form_post"}, c.GetResponseModes())
marshaled, err := json.Marshal(c) marshaled, err := json.Marshal(c)
require.NoError(t, err) require.NoError(t, err)

View File

@ -25,6 +25,7 @@ type Metadata struct {
JWKSURI string `json:"jwks_uri"` JWKSURI string `json:"jwks_uri"`
ResponseTypesSupported []string `json:"response_types_supported"` ResponseTypesSupported []string `json:"response_types_supported"`
ResponseModesSupported []string `json:"response_modes_supported"`
SubjectTypesSupported []string `json:"subject_types_supported"` SubjectTypesSupported []string `json:"subject_types_supported"`
IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
@ -63,6 +64,7 @@ func NewHandler(issuerURL string) http.Handler {
JWKSURI: issuerURL + oidc.JWKSEndpointPath, JWKSURI: issuerURL + oidc.JWKSEndpointPath,
SupervisorDiscovery: SupervisorDiscoveryMetadataV1Alpha1{PinnipedIDPsEndpoint: issuerURL + oidc.PinnipedIDPsPathV1Alpha1}, SupervisorDiscovery: SupervisorDiscoveryMetadataV1Alpha1{PinnipedIDPsEndpoint: issuerURL + oidc.PinnipedIDPsPathV1Alpha1},
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
ResponseModesSupported: []string{"query", "form_post"},
SubjectTypesSupported: []string{"public"}, SubjectTypesSupported: []string{"public"},
IDTokenSigningAlgValuesSupported: []string{"ES256"}, IDTokenSigningAlgValuesSupported: []string{"ES256"},
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},

View File

@ -43,6 +43,7 @@ func TestDiscovery(t *testing.T) {
PinnipedIDPsEndpoint: "https://some-issuer.com/some/path/v1alpha1/pinniped_identity_providers", PinnipedIDPsEndpoint: "https://some-issuer.com/some/path/v1alpha1/pinniped_identity_providers",
}, },
ResponseTypesSupported: []string{"code"}, ResponseTypesSupported: []string{"code"},
ResponseModesSupported: []string{"query", "form_post"},
SubjectTypesSupported: []string{"public"}, SubjectTypesSupported: []string{"public"},
IDTokenSigningAlgValuesSupported: []string{"ES256"}, IDTokenSigningAlgValuesSupported: []string{"ES256"},
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},

View File

@ -14,6 +14,7 @@ import (
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
@ -217,7 +218,7 @@ func FositeOauth2Helper(
MinParameterEntropy: fosite.MinParameterEntropy, MinParameterEntropy: fosite.MinParameterEntropy,
} }
return compose.Compose( provider := compose.Compose(
oauthConfig, oauthConfig,
oauthStore, oauthStore,
&compose.CommonStrategy{ &compose.CommonStrategy{
@ -233,6 +234,8 @@ func FositeOauth2Helper(
compose.OAuth2PKCEFactory, compose.OAuth2PKCEFactory,
TokenExchangeFactory, TokenExchangeFactory,
) )
provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template()
return provider
} }
// FositeErrorForLog generates a list of information about the provided Fosite error that can be // FositeErrorForLog generates a list of information about the provided Fosite error that can be

View File

@ -0,0 +1,87 @@
/* Copyright 2021 the Pinniped contributors. All Rights Reserved. */
/* SPDX-License-Identifier: Apache-2.0 */
body {
font-family: "Metropolis-Light", Helvetica, sans-serif;
}
h1 {
font-size: 20px;
}
.state {
position: absolute;
top: 100px;
left: 50%;
width: 400px;
height: 80px;
margin-top: -40px;
margin-left: -200px;
font-size: 14px;
line-height: 24px;
}
button {
margin: -10px;
padding: 10px;
text-align: left;
width: 100%;
display: inline;
border: none;
background: none;
cursor: pointer;
transition: all .1s;
}
button:hover {
background-color: #eee;
transform: scale(1.01);
}
button:active {
background-color: #ddd;
transform: scale(.99);
}
code {
word-wrap: break-word;
hyphens: auto;
hyphenate-character: '';
font-size: 12px;
font-family: monospace;
color: #333;
}
.copy-icon {
float: left;
width: 36px;
height: 36px;
padding-top: 2px;
padding-right: 10px;
background-size: contain;
background-repeat: no-repeat;
/*
This is the "copy-to-clipboard-line.svg" icon from Clarity (https://clarity.design/):
https://github.com/vmware/clarity-assets/blob/master/icons/essential/copy-to-clipboard-line.svg
*/
background-image: url("data:image/svg+xml,%3Csvg version='1.1' width='36' height='36' viewBox='0 0 36 36' preserveAspectRatio='xMidYMid meet' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E%3Ctitle%3Ecopy-to-clipboard-line%3C/title%3E%3Cpath d='M22.6,4H21.55a3.89,3.89,0,0,0-7.31,0H13.4A2.41,2.41,0,0,0,11,6.4V10H25V6.4A2.41,2.41,0,0,0,22.6,4ZM23,8H13V6.25A.25.25,0,0,1,13.25,6h2.69l.12-1.11A1.24,1.24,0,0,1,16.61,4a2,2,0,0,1,3.15,1.18l.09.84h2.9a.25.25,0,0,1,.25.25Z' class='clr-i-outline clr-i-outline-path-1'%3E%3C/path%3E%3Cpath d='M33.25,18.06H21.33l2.84-2.83a1,1,0,1,0-1.42-1.42L17.5,19.06l5.25,5.25a1,1,0,0,0,.71.29,1,1,0,0,0,.71-1.7l-2.84-2.84H33.25a1,1,0,0,0,0-2Z' class='clr-i-outline clr-i-outline-path-2'%3E%3C/path%3E%3Cpath d='M29,16h2V6.68A1.66,1.66,0,0,0,29.35,5H27.08V7H29Z' class='clr-i-outline clr-i-outline-path-3'%3E%3C/path%3E%3Cpath d='M29,31H7V7H9V5H6.64A1.66,1.66,0,0,0,5,6.67V31.32A1.66,1.66,0,0,0,6.65,33H29.36A1.66,1.66,0,0,0,31,31.33V22.06H29Z' class='clr-i-outline clr-i-outline-path-4'%3E%3C/path%3E%3Crect x='0' y='0' width='36' height='36' fill-opacity='0'/%3E%3C/svg%3E");
}
@keyframes loader {
to {
transform: rotate(360deg);
}
}
#loading {
content: '';
box-sizing: border-box;
width: 80px;
height: 80px;
margin-top: -40px;
margin-left: -40px;
border-radius: 50%;
border: 2px solid #fff;
border-top-color: #1b3951;
animation: loader .6s linear infinite;
}

View File

@ -0,0 +1,34 @@
<!--
Copyright 2021 the Pinniped contributors. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0
--><!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<style>{{ minifiedCSS }}</style>
<script>{{ minifiedJS }}</script>
<link id="favicon" rel="icon"/>
</head>
<body>
<noscript>
To finish logging in, paste this authorization code into your command-line session: {{ .Parameters.Get "code" }}
</noscript>
<form>
<input type="hidden" name="redirect_uri" value="{{ .RedirURL }}"/>
<input type="hidden" name="encoded_params" value="{{ .Parameters.Encode }}"/>
</form>
<div id="loading" class="state" data-favicon="⏳" data-title="Logging in..." hidden></div>
<div id="success" class="state" data-favicon="✅" data-title="Login succeeded" hidden>
<h1>Login succeeded</h1>
<p>You have successfully logged in. You may now close this tab.</p>
</div>
<div id="manual" class="state" data-favicon="⌛" data-title="Finish your login" hidden>
<h1>Finish your login</h1>
<p>To finish logging in, paste this authorization code into your command-line session:</p>
<button id="manual-copy-button">
<span class="copy-icon"></span>
<code id="manual-auth-code">{{ .Parameters.Get "code" }}</code>
</button>
</div>
</body>
</html>

View File

@ -0,0 +1,54 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
window.onload = () => {
const transitionToState = (id) => {
// Hide all the other ".state" <div>s.
Array.from(document.querySelectorAll('.state')).forEach(e => e.hidden = true);
// Unhide the current state <div>.
const currentDiv = document.getElementById(id)
currentDiv.hidden = false;
// Set the window title.
document.title = currentDiv.dataset.title;
// Set the favicon using inline SVG (does not work on Safari).
document.getElementById('favicon').setAttribute(
'href',
'data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>' +
currentDiv.dataset.favicon +
'</text></svg>'
);
}
// At load, show the spinner, hide the other divs, set the favicon, and
// replace the URL path with './' so the upstream auth code disappears.
transitionToState('loading');
window.history.replaceState(null, '', './');
// When the copy button is clicked, copy to the clipboard.
document.getElementById('manual-copy-button').onclick = () => {
const code = document.getElementById('manual-copy-button').innerText;
navigator.clipboard.writeText(code)
.then(() => console.info('copied authorization code ' + code + ' to clipboard'))
.catch(e => console.error('failed to copy code ' + code + ' to clipboard: ' + e));
};
// Set a timeout to transition to the "manual" state if nothing succeeds within 2s.
const timeout = setTimeout(() => transitionToState('manual'), 2000);
// Try to submit the POST callback, handling the success and error cases.
const responseParams = document.forms[0].elements;
fetch(
responseParams['redirect_uri'].value,
{
method: 'POST',
mode: 'no-cors',
headers: {'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8'},
body: responseParams['encoded_params'].value,
})
.then(() => clearTimeout(timeout))
.then(() => transitionToState('success'))
.catch(() => transitionToState('manual'));
};

View File

@ -0,0 +1,65 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package formposthtml defines HTML templates used by the Supervisor.
//nolint: gochecknoglobals // This package uses globals to ensure that all parsing and minifying happens at init.
package formposthtml
import (
"crypto/sha256"
_ "embed" // Needed to trigger //go:embed directives below.
"encoding/base64"
"html/template"
"strings"
"github.com/tdewolff/minify/v2/minify"
)
var (
//go:embed form_post.css
rawCSS string
minifiedCSS = mustMinify(minify.CSS(rawCSS))
//go:embed form_post.js
rawJS string
minifiedJS = mustMinify(minify.JS(rawJS))
//go:embed form_post.gohtml
rawHTMLTemplate string
)
// Parse the Go templated HTML and inject functions providing the minified inline CSS and JS.
var parsedHTMLTemplate = template.Must(template.New("form_post.gohtml").Funcs(template.FuncMap{
"minifiedCSS": func() template.CSS { return template.CSS(minifiedCSS) },
"minifiedJS": func() template.JS { return template.JS(minifiedJS) }, //nolint:gosec // This is 100% static input, not attacker-controlled.
}).Parse(rawHTMLTemplate))
// Generate the CSP header value once since it's effectively constant:
var cspValue = strings.Join([]string{
`default-src 'none'`,
`script-src '` + cspHash(minifiedJS) + `'`,
`style-src '` + cspHash(minifiedCSS) + `'`,
`img-src data:`,
`connect-src *`,
`frame-ancestors 'none'`,
}, "; ")
func mustMinify(s string, err error) string {
if err != nil {
panic(err)
}
return s
}
func cspHash(s string) string {
hashBytes := sha256.Sum256([]byte(s))
return "sha256-" + base64.StdEncoding.EncodeToString(hashBytes[:])
}
// ContentSecurityPolicy returns the Content-Security-Policy header value to make the Template() operate correctly.
//
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/default-src#:~:text=%27%3Chash-algorithm%3E-%3Cbase64-value%3E%27.
func ContentSecurityPolicy() string { return cspValue }
// Template returns the html/template.Template for rendering the response_type=form_post response page.
func Template() *template.Template { return parsedHTMLTemplate }

View File

@ -0,0 +1,101 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package formposthtml
import (
"bytes"
"fmt"
"net/url"
"testing"
"github.com/ory/fosite"
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/here"
)
var (
testRedirectURL = "http://127.0.0.1:12345/callback"
testResponseParams = url.Values{
"code": []string{"test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU"},
"scope": []string{"openid offline_access pinniped:request-audience"},
"state": []string{"01234567890123456789012345678901"},
}
testExpectedFormPostOutput = here.Doc(`
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<style>body{font-family:metropolis-light,Helvetica,sans-serif}h1{font-size:20px}.state{position:absolute;top:100px;left:50%;width:400px;height:80px;margin-top:-40px;margin-left:-200px;font-size:14px;line-height:24px}button{margin:-10px;padding:10px;text-align:left;width:100%;display:inline;border:none;background:0 0;cursor:pointer;transition:all .1s}button:hover{background-color:#eee;transform:scale(1.01)}button:active{background-color:#ddd;transform:scale(.99)}code{word-wrap:break-word;hyphens:auto;hyphenate-character:'';font-size:12px;font-family:monospace;color:#333}.copy-icon{float:left;width:36px;height:36px;padding-top:2px;padding-right:10px;background-size:contain;background-repeat:no-repeat;background-image:url("data:image/svg+xml,%3Csvg width=%2236%22 height=%2236%22 viewBox=%220 0 36 36%22 xmlns=%22http://www.w3.org/2000/svg%22 xmlns:xlink=%22http://www.w3.org/1999/xlink%22%3E%3Ctitle%3Ecopy-to-clipboard-line%3C/title%3E%3Cpath d=%22M22.6 4H21.55a3.89 3.89.0 00-7.31.0H13.4A2.41 2.41.0 0011 6.4V10H25V6.4A2.41 2.41.0 0022.6 4zM23 8H13V6.25A.25.25.0 0113.25 6h2.69l.12-1.11A1.24 1.24.0 0116.61 4a2 2 0 013.15 1.18l.09.84h2.9a.25.25.0 01.25.25z%22 class=%22clr-i-outline clr-i-outline-path-1%22/%3E%3Cpath d=%22M33.25 18.06H21.33l2.84-2.83a1 1 0 10-1.42-1.42L17.5 19.06l5.25 5.25a1 1 0 00.71.29 1 1 0 00.71-1.7l-2.84-2.84H33.25a1 1 0 000-2z%22 class=%22clr-i-outline clr-i-outline-path-2%22/%3E%3Cpath d=%22M29 16h2V6.68A1.66 1.66.0 0029.35 5H27.08V7H29z%22 class=%22clr-i-outline clr-i-outline-path-3%22/%3E%3Cpath d=%22M29 31H7V7H9V5H6.64A1.66 1.66.0 005 6.67V31.32A1.66 1.66.0 006.65 33H29.36A1.66 1.66.0 0031 31.33V22.06H29z%22 class=%22clr-i-outline clr-i-outline-path-4%22/%3E%3Crect x=%220%22 y=%220%22 width=%2236%22 height=%2236%22 fill-opacity=%220%22/%3E%3C/svg%3E")}@keyframes loader{to{transform:rotate(360deg)}}#loading{content:'';box-sizing:border-box;width:80px;height:80px;margin-top:-40px;margin-left:-40px;border-radius:50%;border:2px solid #fff;border-top-color:#1b3951;animation:loader .6s linear infinite}</style>
<script>window.onload=()=>{const a=b=>{Array.from(document.querySelectorAll('.state')).forEach(a=>a.hidden=!0);const a=document.getElementById(b);a.hidden=!1,document.title=a.dataset.title,document.getElementById('favicon').setAttribute('href','data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>'+a.dataset.favicon+'</text></svg>')};a('loading'),window.history.replaceState(null,'','./'),document.getElementById('manual-copy-button').onclick=()=>{const a=document.getElementById('manual-copy-button').innerText;navigator.clipboard.writeText(a).then(()=>console.info('copied authorization code '+a+' to clipboard')).catch(b=>console.error('failed to copy code '+a+' to clipboard: '+b))};const c=setTimeout(()=>a('manual'),2e3),b=document.forms[0].elements;fetch(b.redirect_uri.value,{method:'POST',mode:'no-cors',headers:{'Content-Type':'application/x-www-form-urlencoded;charset=UTF-8'},body:b.encoded_params.value}).then(()=>clearTimeout(c)).then(()=>a('success')).catch(()=>a('manual'))}</script>
<link id="favicon" rel="icon"/>
</head>
<body>
<noscript>
To finish logging in, paste this authorization code into your command-line session: test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU
</noscript>
<form>
<input type="hidden" name="redirect_uri" value="http://127.0.0.1:12345/callback"/>
<input type="hidden" name="encoded_params" value="code=test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU&amp;scope=openid&#43;offline_access&#43;pinniped%3Arequest-audience&amp;state=01234567890123456789012345678901"/>
</form>
<div id="loading" class="state" data-favicon="⏳" data-title="Logging in..." hidden></div>
<div id="success" class="state" data-favicon="✅" data-title="Login succeeded" hidden>
<h1>Login succeeded</h1>
<p>You have successfully logged in. You may now close this tab.</p>
</div>
<div id="manual" class="state" data-favicon="⌛" data-title="Finish your login" hidden>
<h1>Finish your login</h1>
<p>To finish logging in, paste this authorization code into your command-line session:</p>
<button id="manual-copy-button">
<span class="copy-icon"></span>
<code id="manual-auth-code">test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU</code>
</button>
</div>
</body>
</html>
`)
// It's okay if this changes in the future, but this gives us a chance to eyeball the formatting.
// Our browser-based integration tests should find any incompatibilities.
testExpectedCSP = `default-src 'none'; ` +
`script-src 'sha256-U+tKnJ2oMSYKSxmSX3V2mPBN8xdr9JpampKAhbSo108='; ` +
`style-src 'sha256-TLAQE3UR2KpwP7AzMCE4iPDizh7zLPx9UXeK5ntuoRg='; ` +
`img-src data:; ` +
`connect-src *; ` +
`frame-ancestors 'none'`
)
func TestTemplate(t *testing.T) {
// Use the Fosite helper to render the form, ensuring that the parameters all have the same names + types.
var buf bytes.Buffer
fosite.WriteAuthorizeFormPostResponse(testRedirectURL, testResponseParams, Template(), &buf)
// Render again so we can confirm that there is no error returned (Fosite ignores any error).
var buf2 bytes.Buffer
require.NoError(t, Template().Execute(&buf2, struct {
RedirURL string
Parameters url.Values
}{
RedirURL: testRedirectURL,
Parameters: testResponseParams,
}))
require.Equal(t, buf.String(), buf2.String())
require.Equal(t, testExpectedFormPostOutput, buf.String())
}
func TestContentSecurityPolicyHashes(t *testing.T) {
require.Equal(t, testExpectedCSP, ContentSecurityPolicy())
}
func TestHelpers(t *testing.T) {
// These are silly tests but it's easy to we might as well have them.
require.Equal(t, "test", mustMinify("test", nil))
require.PanicsWithError(t, "some error", func() { mustMinify("", fmt.Errorf("some error")) })
// Example test vector from https://content-security-policy.com/hash/.
require.Equal(t, "sha256-RFWPLDbv2BY+rCkDzsE+0fr8ylGr2R2faWMhq4lfEQc=", cspHash("doSomething();"))
}

View File

@ -1,4 +1,4 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package testutil package testutil
@ -55,7 +55,9 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
} }
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
require.Equal(t, "default-src 'none'; frame-ancestors 'none'", response.Header().Get("Content-Security-Policy")) // This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use.
require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'")
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options")) require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection")) require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))
require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options")) require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options"))

View File

@ -235,10 +235,10 @@ func VerifyECDSAIDToken(
return token return token
} }
func RequireAuthcodeRedirectLocation( func RequireAuthCodeRegexpMatch(
t *testing.T, t *testing.T,
actualRedirectLocation string, actualContent string,
wantRedirectLocationRegexp string, wantRegexp string,
kubeClient *fake.Clientset, kubeClient *fake.Clientset,
secretsClient v1.SecretInterface, secretsClient v1.SecretInterface,
oauthStore fositestoragei.AllFositeStorage, oauthStore fositestoragei.AllFositeStorage,
@ -256,9 +256,9 @@ func RequireAuthcodeRedirectLocation(
t.Helper() t.Helper()
// Assert that Location header matches regular expression. // Assert that Location header matches regular expression.
regex := regexp.MustCompile(wantRedirectLocationRegexp) regex := regexp.MustCompile(wantRegexp)
submatches := regex.FindStringSubmatch(actualRedirectLocation) submatches := regex.FindStringSubmatch(actualContent)
require.Lenf(t, submatches, 2, "no regexp match in actualRedirectLocation: %q", actualRedirectLocation) require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent)
capturedAuthCode := submatches[1] capturedAuthCode := submatches[1]
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface

View File

@ -10,6 +10,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"mime" "mime"
"net" "net"
"net/http" "net/http"
@ -17,6 +18,7 @@ import (
"os" "os"
"sort" "sort"
"strings" "strings"
"syscall"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
@ -87,6 +89,7 @@ type handlerState struct {
// Generated parameters of a login flow. // Generated parameters of a login flow.
provider *oidc.Provider provider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
useFormPost bool
state state.State state state.State
nonce nonce.Nonce nonce nonce.Nonce
pkce pkce.Code pkce pkce.Code
@ -96,10 +99,12 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
listen func(string, string) (net.Listener, error)
isTTY func(int) bool
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error) promptForSecret func(ctx context.Context, promptLabel string) (string, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -156,6 +161,9 @@ func WithScopes(scopes []string) Option {
// WithBrowserOpen overrides the default "open browser" functionality with a custom callback. If not specified, // WithBrowserOpen overrides the default "open browser" functionality with a custom callback. If not specified,
// an implementation using https://github.com/pkg/browser will be used by default. // an implementation using https://github.com/pkg/browser will be used by default.
//
// Deprecated: this option will be removed in a future version of Pinniped. See the
// WithSkipBrowserOpen() option instead.
func WithBrowserOpen(openURL func(url string) error) Option { func WithBrowserOpen(openURL func(url string) error) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.openURL = openURL h.openURL = openURL
@ -163,6 +171,23 @@ func WithBrowserOpen(openURL func(url string) error) Option {
} }
} }
// WithSkipBrowserOpen causes the login to only print the authorize URL, but skips attempting to
// open the user's default web browser.
func WithSkipBrowserOpen() Option {
return func(h *handlerState) error {
h.openURL = func(_ string) error { return nil }
return nil
}
}
// WithSkipListen causes the login skip starting the localhost listener, forcing the manual copy/paste login flow.
func WithSkipListen() Option {
return func(h *handlerState) error {
h.listen = func(string, string) (net.Listener, error) { return nil, nil }
return nil
}
}
// SessionCacheKey contains the data used to select a valid session cache entry. // SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct { type SessionCacheKey struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
@ -250,6 +275,8 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
listen: net.Listen,
isTTY: term.IsTerminal,
getProvider: upstreamoidc.New, getProvider: upstreamoidc.New,
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) { validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token) return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
@ -376,11 +403,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error. // and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Ask the user for their username and password. // Ask the user for their username and password.
username, err := h.promptForValue(defaultLDAPUsernamePrompt) username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
if err != nil { if err != nil {
return nil, fmt.Errorf("error prompting for username: %w", err) return nil, fmt.Errorf("error prompting for username: %w", err)
} }
password, err := h.promptForSecret(defaultLDAPPasswordPrompt) password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
if err != nil { if err != nil {
return nil, fmt.Errorf("error prompting for password: %w", err) return nil, fmt.Errorf("error prompting for password: %w", err)
} }
@ -475,30 +502,55 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
// Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint. // Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint.
// Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error. // Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error.
func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number). // Attempt to open a local TCP listener, logging but otherwise ignoring any error.
listener, err := net.Listen("tcp", h.listenAddr) listener, err := h.listen("tcp", h.listenAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err) h.logger.V(debugLogLevel).Error(err, "could not open callback listener")
}
// If the listener failed to start and stdin is not a TTY, then we have no hope of succeeding,
// since we won't be able to receive the web callback and we can't prompt for the manual auth code.
if listener == nil && !h.isTTY(syscall.Stdin) {
return nil, fmt.Errorf("login failed: must have either a localhost listener or stdin must be a TTY")
}
// Update the OAuth2 redirect_uri to match the actual listener address (if there is one), or just use
// a fake ":0" port if there is no listener running.
redirectURI := url.URL{Scheme: "http", Path: h.callbackPath}
if listener == nil {
redirectURI.Host = "127.0.0.1:0"
} else {
redirectURI.Host = listener.Addr().String()
}
h.oauth2Config.RedirectURL = redirectURI.String()
// If the server supports it, request response_mode=form_post.
authParams := *authorizeOptions
if h.useFormPost {
authParams = append(authParams, oauth2.SetAuthURLParam("response_mode", "form_post"))
} }
h.oauth2Config.RedirectURL = (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String()
// Now that we have a redirect URL with the listener port, we can build the authorize URL. // Now that we have a redirect URL with the listener port, we can build the authorize URL.
authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), *authorizeOptions...) authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...)
// Start a callback server in a background goroutine. // If there is a listener running, start serving the callback handler in a background goroutine.
if listener != nil {
shutdown := h.serve(listener) shutdown := h.serve(listener)
defer shutdown() defer shutdown()
// Open the authorize URL in the users browser.
if err := h.openURL(authorizeURL); err != nil {
return nil, fmt.Errorf("could not open browser: %w", err)
} }
// Wait for either the callback or a timeout. // Open the authorize URL in the users browser, logging but otherwise ignoring any error.
if err := h.openURL(authorizeURL); err != nil {
h.logger.V(debugLogLevel).Error(err, "could not open browser")
}
ctx, cancel := context.WithCancel(h.ctx)
defer cancel()
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
// Wait for either the web callback, a pasted auth code, or a timeout.
select { select {
case <-h.ctx.Done(): case <-h.ctx.Done():
return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err()) return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err())
@ -510,7 +562,37 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
} }
} }
func promptForValue(promptLabel string) (string, error) { func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it.
if !h.isTTY(syscall.Stdin) {
return
}
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
// code because the user isn't going to have any easy way to manually copy it anyway.
if !h.useFormPost {
return
}
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
// if the parent context is cancelled (when the login succeeds or times out).
go func() {
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
if err != nil {
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
}
// When a code is pasted, redeem it for a token and return that result on the callbacks channel.
token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err}
}()
}
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) { if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
@ -518,6 +600,15 @@ func promptForValue(promptLabel string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
}()
text, err := bufio.NewReader(os.Stdin).ReadString('\n') text, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil { if err != nil {
return "", fmt.Errorf("could read input from stdin: %w", err) return "", fmt.Errorf("could read input from stdin: %w", err)
@ -526,7 +617,7 @@ func promptForValue(promptLabel string) (string, error) {
return text, nil return text, nil
} }
func promptForSecret(promptLabel string) (string, error) { func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) { if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
@ -534,16 +625,26 @@ func promptForSecret(promptLabel string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
password, err := term.ReadPassword(0)
if err != nil { // If the context is canceled, set the read deadline on stdin so the read immediately finishes.
return "", fmt.Errorf("could not read password: %w", err) ctx, cancel := context.WithCancel(ctx)
} defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
// term.ReadPassword swallows the newline that was typed by the user, so to // term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password // avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline. // prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n") //
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
// on stderr is formatted correctly.
_, _ = fmt.Fprint(os.Stderr, "\n")
}()
password, err := term.ReadPassword(syscall.Stdin)
if err != nil { if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err) return "", fmt.Errorf("could not read password: %w", err)
} }
return string(password), err return string(password), err
} }
@ -567,9 +668,27 @@ func (h *handlerState) initOIDCDiscovery() error {
Endpoint: h.provider.Endpoint(), Endpoint: h.provider.Endpoint(),
Scopes: h.scopes, Scopes: h.scopes,
} }
// Use response_mode=form_post if the provider supports it.
var discoveryClaims struct {
ResponseModesSupported []string `json:"response_modes_supported"`
}
if err := h.provider.Claims(&discoveryClaims); err != nil {
return fmt.Errorf("could not decode response_modes_supported in OIDC discovery from %q: %w", h.issuer, err)
}
h.useFormPost = stringSliceContains(discoveryClaims.ResponseModesSupported, "form_post")
return nil return nil
} }
func stringSliceContains(slice []string, s string) bool {
for _, item := range slice {
if item == s {
return true
}
}
return false
}
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) { func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience) h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience)
// Perform OIDC discovery. This may have already been performed if there was not a cached base token. // Perform OIDC discovery. This may have already been performed if there was not a cached base token.
@ -664,13 +783,29 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} }
}() }()
var params url.Values
if h.useFormPost {
// Return HTTP 405 for anything that's not a POST.
if r.Method != http.MethodPost {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST")
}
// Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
if err := r.ParseForm(); err != nil {
return httperr.Wrap(http.StatusBadRequest, "invalid form", err)
}
params = r.Form
} else {
// Return HTTP 405 for anything that's not a GET. // Return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
} }
// Pull response parameters from the URL query string.
params = r.URL.Query()
}
// Validate OAuth2 state and fail if it's incorrect (to block CSRF). // Validate OAuth2 state and fail if it's incorrect (to block CSRF).
params := r.URL.Query()
if err := h.state.Validate(params.Get("state")); err != nil { if err := h.state.Validate(params.Get("state")); err != nil {
return httperr.New(http.StatusForbidden, "missing or invalid state parameter") return httperr.New(http.StatusForbidden, "missing or invalid state parameter")
} }
@ -685,14 +820,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// Exchange the authorization code for access, ID, and refresh tokens and perform required // Exchange the authorization code for access, ID, and refresh tokens and perform required
// validations on the returned ID token. // validations on the returned ID token.
token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). token, err := h.redeemAuthCode(r.Context(), params.Get("code"))
ExchangeAuthcodeAndValidateTokens(
r.Context(),
params.Get("code"),
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }
@ -702,6 +830,17 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return nil return nil
} }
func (h *handlerState) redeemAuthCode(ctx context.Context, code string) (*oidctypes.Token, error) {
return h.getProvider(h.oauth2Config, h.provider, h.httpClient).
ExchangeAuthcodeAndValidateTokens(
ctx,
code,
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
}
func (h *handlerState) serve(listener net.Listener) func() { func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))

View File

@ -4,15 +4,18 @@
package oidcclient package oidcclient
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"syscall"
"testing" "testing"
"time" "time"
@ -80,6 +83,22 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns discovery data with a broken response_modes_supported value.
brokenResponseModeMux := http.NewServeMux()
brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux)
brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string).
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenResponseModeServer.URL,
ResponseModesSupported: "invalid",
})
})
t.Cleanup(brokenResponseModeServer.Close)
// Start a test server that returns discovery data with a broken token URL // Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux() brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux) brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
@ -100,30 +119,29 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}) })
t.Cleanup(brokenTokenURLServer.Close) t.Cleanup(brokenTokenURLServer.Close)
// Start a test server that returns a real discovery document and answers refresh requests. discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc {
providerMux := http.NewServeMux() return func(w http.ResponseWriter, r *http.Request) {
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed) http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return return
} }
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", "application/json")
type providerJSON struct { _ = json.NewEncoder(w).Encode(&struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"` AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"` TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"` JWKSURL string `json:"jwks_uri"`
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
}{
Issuer: server.URL,
AuthURL: server.URL + "/authorize",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/keys",
ResponseModesSupported: responseModes,
})
} }
_ = json.NewEncoder(w).Encode(&providerJSON{ }
Issuer: successServer.URL, tokenHandler := func(w http.ResponseWriter, r *http.Request) {
AuthURL: successServer.URL + "/authorize",
TokenURL: successServer.URL + "/token",
JWKSURL: successServer.URL + "/keys",
})
})
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed) http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return return
@ -204,7 +222,21 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response)) require.NoError(t, json.NewEncoder(w).Encode(&response))
}) }
// Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil))
providerMux.HandleFunc("/token", tokenHandler)
// Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post.
formPostProviderMux := http.NewServeMux()
formPostSuccessServer := httptest.NewServer(formPostProviderMux)
t.Cleanup(formPostSuccessServer.Close)
formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"}))
formPostProviderMux.HandleFunc("/token", tokenHandler)
defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam
// Call the handler function from the test server to calculate the response. // Call the handler function from the test server to calculate the response.
@ -218,8 +250,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil } h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil } h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil } h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil} cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -349,7 +381,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
name: "discovery failure", name: "discovery failure due to 500 error",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil } return func(h *handlerState) error { return nil }
}, },
@ -357,6 +389,15 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""},
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
}, },
{
name: "discovery failure due to invalid response_modes_supported",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
issuer: brokenResponseModeServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""},
wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL),
},
{ {
name: "session cache hit with refreshable token", name: "session cache hit with refreshable token",
issuer: successServer.URL, issuer: successServer.URL,
@ -451,38 +492,93 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}) })
h.cache = cache h.cache = cache
h.listenAddr = "invalid-listen-address" h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(int) bool { return false }
return nil return nil
} }
}, },
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"", wantLogs: []string{
"\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""}, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"level"=4 "msg"="Pinniped: Refreshing cached token."`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
// Expect this to fall through to the authorization code flow, so it fails here. // Expect this to fall through to the authorization code flow, so it fails here.
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
}, },
{ {
name: "listen failure", name: "listen failure and non-tty stdin",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.listenAddr = "invalid-listen-address" h.listen = func(net string, addr string) (net.Listener, error) {
assert.Equal(t, "tcp", net)
assert.Equal(t, "localhost:0", addr)
return nil, fmt.Errorf("some listen error")
}
h.isTTY = func(fd int) bool {
assert.Equal(t, fd, syscall.Stdin)
return false
}
return nil return nil
} }
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
}, },
{ {
name: "browser open failure", name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return WithBrowserOpen(func(url string) error { return func(h *handlerState) error {
require.NoError(t, WithSkipListen()(h))
h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error") return fmt.Errorf("some browser open error")
}) }
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
}, },
issuer: successServer.URL, issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{
wantErr: "could not open browser: some browser open error", `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "timeout waiting for callback", name: "timeout waiting for callback",
@ -580,6 +676,68 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken, wantToken: &testToken,
}, },
{
name: "callback returns success with request_mode=form_post",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: formPostSuccessServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"response_mode": []string{"form_post"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantToken: &testToken,
},
{ {
name: "upstream name and type are included in authorize request if upstream name is provided", name: "upstream name and type are included in authorize request if upstream name is provided",
clientID: "test-client-id", clientID: "test-client-id",
@ -650,7 +808,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error") return "", errors.New("some prompt error")
} }
@ -667,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") } h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
return nil return nil
} }
}, },
@ -853,11 +1011,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil } h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil return "some-upstream-username", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel) require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil return "some-upstream-password", nil
} }
@ -1287,10 +1445,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
WithContext(context.Background()), WithContext(context.Background()),
WithListenPort(0), WithListenPort(0),
WithScopes([]string{"test-scope"}), WithScopes([]string{"test-scope"}),
WithSkipBrowserOpen(),
tt.opt(t), tt.opt(t),
WithLogger(testLogger), WithLogger(testLogger),
) )
require.Equal(t, tt.wantLogs, testLogger.Lines()) testLogger.Expect(tt.wantLogs)
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr) require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok) require.Nil(t, tok)
@ -1324,13 +1483,152 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
} }
} }
func TestHandleAuthCodeCallback(t *testing.T) { func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback" const testRedirectURI = "http://127.0.0.1:12324/callback"
tests := []struct {
name string
opt func(t *testing.T) Option
wantCallback *callbackResult
}{
{
name: "no stdin available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool {
require.Equal(t, syscall.Stdin, fd)
return false
}
h.useFormPost = true
return nil
}
},
},
{
name: "no form_post mode available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = false
return nil
}
},
},
{
name: "prompt fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"),
},
},
{
name: "redeeming code fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("some exchange error"),
},
},
{
name: "success",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
wantCallback: &callbackResult{
token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var buf bytes.Buffer
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(),
)
if tt.wantCallback != nil {
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
case result := <-h.callbacks:
require.Equal(t, *tt.wantCallback, result)
}
}
})
}
}
func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
withFormPostMode := func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
return nil
}
}
tests := []struct { tests := []struct {
name string name string
method string method string
query string query string
body []byte
contentType string
opt func(t *testing.T) Option opt func(t *testing.T) Option
wantErr string wantErr string
wantHTTPStatus int wantHTTPStatus int
@ -1342,6 +1640,24 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantErr: "wanted GET", wantErr: "wanted GET",
wantHTTPStatus: http.StatusMethodNotAllowed, wantHTTPStatus: http.StatusMethodNotAllowed,
}, },
{
name: "wrong method for form_post",
method: "GET",
query: "",
opt: withFormPostMode,
wantErr: "wanted POST",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "invalid form for form_post",
method: "POST",
query: "",
contentType: "application/x-www-form-urlencoded",
body: []byte(`%`),
opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`,
wantHTTPStatus: http.StatusBadRequest,
},
{ {
name: "invalid state", name: "invalid state",
query: "state=invalid", query: "state=invalid",
@ -1396,6 +1712,26 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
}, },
}, },
{
name: "valid form_post",
method: http.MethodPost,
contentType: "application/x-www-form-urlencoded",
body: []byte(`state=test-state&code=valid`),
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
@ -1414,12 +1750,15 @@ func TestHandleAuthCodeCallback(t *testing.T) {
defer cancel() defer cancel()
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil) req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body))
require.NoError(t, err) require.NoError(t, err)
req.URL.RawQuery = tt.query req.URL.RawQuery = tt.query
if tt.method != "" { if tt.method != "" {
req.Method = tt.method req.Method = tt.method
} }
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
err = h.handleAuthCodeCallback(resp, req) err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" { if tt.wantErr != "" {

View File

@ -307,16 +307,15 @@ func runPinnipedLoginOIDC(
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr)) reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr))
scanner := bufio.NewScanner(reader) scanner := bufio.NewScanner(reader)
const prompt = "Please log in: "
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
if strings.HasPrefix(line, prompt) { if err == nil && loginURL.Scheme == "https" {
loginURLChan <- strings.TrimPrefix(line, prompt) loginURLChan <- loginURL.String()
return nil return nil
} }
} }
return fmt.Errorf("expected stderr to contain %s", prompt) return fmt.Errorf("expected stderr to contain login URL")
}) })
// Start a background goroutine to read stdout from the CLI and parse out an ExecCredential. // Start a background goroutine to read stdout from the CLI and parse out an ExecCredential.

View File

@ -109,7 +109,7 @@ func TestE2EFullIntegration(t *testing.T) {
}) })
// Add an OIDC upstream IDP and try using it to authenticate during kubectl commands. // Add an OIDC upstream IDP and try using it to authenticate during kubectl commands.
t.Run("with Supervisor OIDC upstream IDP", func(t *testing.T) { t.Run("with Supervisor OIDC upstream IDP and automatic flow", func(t *testing.T) {
expectedUsername := env.SupervisorUpstreamOIDC.Username expectedUsername := env.SupervisorUpstreamOIDC.Username
expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups
@ -195,16 +195,15 @@ func TestE2EFullIntegration(t *testing.T) {
}() }()
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe)) reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe))
line, err := reader.ReadString('\n') scanner := bufio.NewScanner(reader)
if err != nil { for scanner.Scan() {
return fmt.Errorf("could not read login URL line from stderr: %w", err) loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
if err == nil && loginURL.Scheme == "https" {
loginURLChan <- loginURL.String()
return nil
} }
const prompt = "Please log in: "
if !strings.HasPrefix(line, prompt) {
return fmt.Errorf("expected %q to have prefix %q", line, prompt)
} }
loginURLChan <- strings.TrimPrefix(line, prompt) return fmt.Errorf("expected stderr to contain login URL")
return readAndExpectEmpty(reader)
}) })
// Start a background goroutine to read stdout from kubectl and return the result as a string. // Start a background goroutine to read stdout from kubectl and return the result as a string.
@ -242,17 +241,13 @@ func TestE2EFullIntegration(t *testing.T) {
// Expect to be redirected to the upstream provider and log in. // Expect to be redirected to the upstream provider and log in.
browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC) browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC)
// Expect to be redirected to the localhost callback. // Expect to be redirected to the downstream callback which is serving the form_post HTML.
t.Logf("waiting for redirect to callback") t.Logf("waiting for response page %s", downstream.Spec.Issuer)
browsertest.WaitForURL(t, page, regexp.MustCompile(`\Ahttp://127\.0\.0\.1:[0-9]+/callback\?.+\z`)) browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer)))
// Wait for the "pre" element that gets rendered for a `text/plain` page, and // The response page should have done the background fetch() and POST'ed to the CLI's callback.
// assert that it contains the success message. // It should now be in the "success" state.
t.Logf("verifying success page") formpostExpectSuccessState(t, page)
browsertest.WaitForVisibleElements(t, page, "pre")
msg, err := page.First("pre").Text()
require.NoError(t, err)
require.Equal(t, "you have been logged in and may now close this tab", msg)
// Expect the CLI to output a list of namespaces in JSON format. // Expect the CLI to output a list of namespaces in JSON format.
t.Logf("waiting for kubectl to output namespace list JSON") t.Logf("waiting for kubectl to output namespace list JSON")
@ -275,6 +270,113 @@ func TestE2EFullIntegration(t *testing.T) {
) )
}) })
t.Run("with Supervisor OIDC upstream IDP and manual flow", func(t *testing.T) {
expectedUsername := env.SupervisorUpstreamOIDC.Username
expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups
// Create a ClusterRoleBinding to give our test user from the upstream read-only access to the cluster.
testlib.CreateTestClusterRoleBinding(t,
rbacv1.Subject{Kind: rbacv1.UserKind, APIGroup: rbacv1.GroupName, Name: expectedUsername},
rbacv1.RoleRef{Kind: "ClusterRole", APIGroup: rbacv1.GroupName, Name: "view"},
)
testlib.WaitForUserToHaveAccess(t, expectedUsername, []string{}, &authorizationv1.ResourceAttributes{
Verb: "get",
Group: "",
Version: "v1",
Resource: "namespaces",
})
// Create upstream OIDC provider and wait for it to become ready.
testlib.CreateTestOIDCIdentityProvider(t, idpv1alpha1.OIDCIdentityProviderSpec{
Issuer: env.SupervisorUpstreamOIDC.Issuer,
TLS: &idpv1alpha1.TLSSpec{
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamOIDC.CABundle)),
},
AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{
AdditionalScopes: env.SupervisorUpstreamOIDC.AdditionalScopes,
},
Claims: idpv1alpha1.OIDCClaims{
Username: env.SupervisorUpstreamOIDC.UsernameClaim,
Groups: env.SupervisorUpstreamOIDC.GroupsClaim,
},
Client: idpv1alpha1.OIDCClient{
SecretName: testlib.CreateClientCredsSecret(t, env.SupervisorUpstreamOIDC.ClientID, env.SupervisorUpstreamOIDC.ClientSecret).Name,
},
}, idpv1alpha1.PhaseReady)
// Use a specific session cache for this test.
sessionCachePath := tempDir + "/oidc-test-sessions-manual.yaml"
kubeconfigPath := runPinnipedGetKubeconfig(t, env, pinnipedExe, tempDir, []string{
"get", "kubeconfig",
"--concierge-api-group-suffix", env.APIGroupSuffix,
"--concierge-authenticator-type", "jwt",
"--concierge-authenticator-name", authenticator.Name,
"--oidc-skip-browser",
"--oidc-skip-listen",
"--oidc-ca-bundle", testCABundlePath,
"--oidc-session-cache", sessionCachePath,
})
// Run "kubectl get namespaces" which should trigger a browser login via the plugin.
start := time.Now()
kubectlCmd := exec.CommandContext(ctx, "kubectl", "get", "namespace", "--kubeconfig", kubeconfigPath)
kubectlCmd.Env = append(os.Environ(), env.ProxyEnv()...)
ptyFile, err := pty.Start(kubectlCmd)
require.NoError(t, err)
// Wait for the subprocess to print the login prompt.
t.Logf("waiting for CLI to output login URL and manual prompt")
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
require.Contains(t, output, "Log in by visiting this link:")
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
// Find the line with the login URL.
var loginURL string
for _, line := range strings.Split(output, "\n") {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "https://") {
loginURL = trimmed
}
}
require.NotEmptyf(t, loginURL, "didn't find login URL in output: %s", output)
t.Logf("navigating to login page")
require.NoError(t, page.Navigate(loginURL))
// Expect to be redirected to the upstream provider and log in.
browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC)
// Expect to be redirected to the downstream callback which is serving the form_post HTML.
t.Logf("waiting for response page %s", downstream.Spec.Issuer)
browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer)))
// The response page should have failed to automatically post, and should now be showing the manual instructions.
authCode := formpostExpectManualState(t, page)
// Enter the auth code in the waiting prompt, followed by a newline.
t.Logf("'manually' pasting authorization code %q to waiting prompt", authCode)
_, err = ptyFile.WriteString(authCode + "\n")
require.NoError(t, err)
// Read all of the remaining output from the subprocess until EOF.
t.Logf("waiting for kubectl to output namespace list")
remainingOutput, _ := ioutil.ReadAll(ptyFile)
// Ignore any errors returned because there is always an error on linux.
require.Greaterf(t, len(remainingOutput), 0, "expected to get some more output from the kubectl subcommand, but did not")
require.Greaterf(t, len(strings.Split(string(remainingOutput), "\n")), 2, "expected some namespaces to be returned, got %q", string(remainingOutput))
t.Logf("first kubectl command took %s", time.Since(start).String())
requireUserCanUseKubectlWithoutAuthenticatingAgain(ctx, t, env,
downstream,
kubeconfigPath,
sessionCachePath,
pinnipedExe,
expectedUsername,
expectedGroups,
)
})
// Add an LDAP upstream IDP and try using it to authenticate during kubectl commands. // Add an LDAP upstream IDP and try using it to authenticate during kubectl commands.
t.Run("with Supervisor LDAP upstream IDP", func(t *testing.T) { t.Run("with Supervisor LDAP upstream IDP", func(t *testing.T) {
if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) { if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) {
@ -376,7 +478,7 @@ func TestE2EFullIntegration(t *testing.T) {
}) })
} }
func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) { func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) string {
readFromFile := "" readFromFile := ""
testlib.RequireEventuallyWithoutError(t, func() (bool, error) { testlib.RequireEventuallyWithoutError(t, func() (bool, error) {
@ -390,6 +492,7 @@ func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) {
} }
return false, nil // keep waiting and reading return false, nil // keep waiting and reading
}, 1*time.Minute, 1*time.Second) }, 1*time.Minute, 1*time.Second)
return readFromFile
} }
func readAvailableOutput(t *testing.T, r io.Reader) (string, bool) { func readAvailableOutput(t *testing.T, r io.Reader) (string, bool) {

View File

@ -0,0 +1,257 @@
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package integration
import (
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/token/hmac"
"github.com/sclevine/agouti"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/test/testlib"
"go.pinniped.dev/test/testlib/browsertest"
)
func TestFormPostHTML(t *testing.T) {
// Run a mock callback handler, simulating the one running in the CLI.
callbackURL, expectCallback := formpostCallbackServer(t)
// Open a single browser for all subtests to use (in sequence).
page := browsertest.Open(t)
t.Run("success", func(t *testing.T) {
// Serve the form_post template with successful parameters.
responseParams := formpostRandomParams(t)
formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams))
// Now we handle the callback and assert that we got what we expected. This should transition
// the UI into the success state.
expectCallback(t, responseParams)
formpostExpectSuccessState(t, page)
})
t.Run("callback server error", func(t *testing.T) {
// Serve the form_post template with a redirect URI that will return an HTTP 500 response.
responseParams := formpostRandomParams(t)
formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=500", responseParams))
// Now we handle the callback and assert that we got what we expected.
expectCallback(t, responseParams)
// This is not 100% the behavior we'd like, but because our JS is making
// a cross-origin fetch() without CORS, we don't get to know anything
// about the response (even whether it is 200 vs. 500), so this case
// is the same as the success case.
//
// This case is fairly unlikely in practice, and if the CLI encounters
// an error it can also expose it via stderr anyway.
formpostExpectSuccessState(t, page)
})
t.Run("network failure", func(t *testing.T) {
// Serve the form_post template with a redirect URI that will return a network error.
responseParams := formpostRandomParams(t)
formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=close", responseParams))
// Now we handle the callback and assert that we got what we expected.
// This will trigger the callback server to close the client connection abruptly because
// of the `?fail=close` parameter above.
expectCallback(t, responseParams)
// This failure should cause the UI to enter the "manual" state.
actualCode := formpostExpectManualState(t, page)
require.Equal(t, responseParams.Get("code"), actualCode)
})
t.Run("timeout", func(t *testing.T) {
// Serve the form_post template with successful parameters.
responseParams := formpostRandomParams(t)
formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams))
// Sleep for longer than the two second timeout.
// During this sleep we are blocking the callback from returning.
time.Sleep(3 * time.Second)
// Assert that the timeout fires and we see the manual instructions.
actualCode := formpostExpectManualState(t, page)
require.Equal(t, responseParams.Get("code"), actualCode)
// Now simulate the callback finally succeeding, in which case
// the manual instructions should disappear and we should see the success
// div instead.
expectCallback(t, responseParams)
formpostExpectSuccessState(t, page)
})
}
// formpostCallbackServer runs a test server that simulates the CLI's callback handler.
// It returns the URL of the running test server and a function for fetching the next
// received form POST parameters.
//
// The test server supports special `?fail=close` and `?fail=500` to force error cases.
func formpostCallbackServer(t *testing.T) (string, func(*testing.T, url.Values)) {
results := make(chan url.Values)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NoError(t, r.ParseForm())
// Extract only the POST parameters (r.Form also contains URL query parameters).
postParams := url.Values{}
for k := range r.Form {
if v := r.PostFormValue(k); v != "" {
postParams.Set(k, v)
}
}
// Send the form parameters back on the results channel, giving up if the
// request context is cancelled (such as if the client disconnects).
select {
case results <- postParams:
case <-r.Context().Done():
return
}
switch r.URL.Query().Get("fail") {
case "close": // If "fail=close" is passed, close the connection immediately.
if conn, _, err := w.(http.Hijacker).Hijack(); err == nil {
_ = conn.Close()
}
return
case "500": // If "fail=500" is passed, return a 500 error.
w.WriteHeader(http.StatusInternalServerError)
return
}
}))
t.Cleanup(func() {
close(results)
server.Close()
})
return server.URL, func(t *testing.T, expected url.Values) {
t.Logf("expecting to get a POST callback...")
select {
case actual := <-results:
require.Equal(t, expected, actual, "did not receive expected callback")
case <-time.After(3 * time.Second):
t.Errorf("failed to receive expected callback %v", expected)
t.FailNow()
}
}
}
// formpostTemplateServer runs a test server that serves formposthtml.Template() rendered with test parameters.
func formpostTemplateServer(t *testing.T, redirectURI string, responseParams url.Values) string {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fosite.WriteAuthorizeFormPostResponse(redirectURI, responseParams, formposthtml.Template(), w)
})
server := httptest.NewServer(securityheader.WrapWithCustomCSP(
handler,
formposthtml.ContentSecurityPolicy(),
))
t.Cleanup(server.Close)
return server.URL
}
// formpostRandomParams is a helper to generate random OAuth2 response parameters for testing.
func formpostRandomParams(t *testing.T) url.Values {
generator := &hmac.HMACStrategy{GlobalSecret: testlib.RandBytes(t, 32), TokenEntropy: 32}
authCode, _, err := generator.Generate()
require.NoError(t, err)
return url.Values{
"code": []string{authCode},
"scope": []string{"openid offline_access pinniped:request-audience"},
"state": []string{testlib.RandHex(t, 16)},
}
}
// formpostExpectTitle asserts that the page has the expected title.
func formpostExpectTitle(t *testing.T, page *agouti.Page, expected string) {
actual, err := page.Title()
require.NoError(t, err)
require.Equal(t, expected, actual)
}
// formpostExpectTitle asserts that the page has the expected SVG/emoji favicon.
func formpostExpectFavicon(t *testing.T, page *agouti.Page, expected string) {
iconURL, err := page.First("#favicon").Attribute("href")
require.NoError(t, err)
require.True(t, strings.HasPrefix(iconURL, "data:image/svg+xml,<svg"))
// For some reason chromedriver on Linux returns this attribute urlencoded, but on macOS it contains the
// original emoji bytes (unescaped). To check correctly in both cases we allow either version here.
expectedEscaped := url.QueryEscape(expected)
require.Truef(t,
strings.Contains(iconURL, expected) || strings.Contains(iconURL, expectedEscaped),
"expected %q to contain %q or %q", iconURL, expected, expectedEscaped,
)
}
// formpostInitiate navigates to the template server endpoint and expects the
// loading animation to be shown.
func formpostInitiate(t *testing.T, page *agouti.Page, url string) {
require.NoError(t, page.Reset())
t.Logf("navigating to mock form_post template URL %s...", url)
require.NoError(t, page.Navigate(url))
t.Logf("expecting to see loading animation...")
browsertest.WaitForVisibleElements(t, page, "#loading")
formpostExpectTitle(t, page, "Logging in...")
formpostExpectFavicon(t, page, "⏳")
}
// formpostExpectSuccessState asserts that the page is in the "success" state.
func formpostExpectSuccessState(t *testing.T, page *agouti.Page) {
t.Logf("expecting to see success message become visible...")
browsertest.WaitForVisibleElements(t, page, "#success")
successDivText, err := page.First("#success").Text()
require.NoError(t, err)
require.Contains(t, successDivText, "Login succeeded")
require.Contains(t, successDivText, "You have successfully logged in. You may now close this tab.")
formpostExpectTitle(t, page, "Login succeeded")
formpostExpectFavicon(t, page, "✅")
}
// formpostExpectManualState asserts that the page is in the "manual" state and returns the auth code.
func formpostExpectManualState(t *testing.T, page *agouti.Page) string {
t.Logf("expecting to see manual message become visible...")
browsertest.WaitForVisibleElements(t, page, "#manual")
manualDivText, err := page.First("#manual").Text()
require.NoError(t, err)
require.Contains(t, manualDivText, "Finish your login")
require.Contains(t, manualDivText, "To finish logging in, paste this authorization code into your command-line session:")
formpostExpectTitle(t, page, "Finish your login")
formpostExpectFavicon(t, page, "⌛")
// Click the copy button and expect that the code is copied to the clipboard. Unfortunately,
// headless Chrome does not have a real clipboard we can check, so we rely on checking a
// console.log() statement that happens at the same time.
t.Logf("clicking the 'copy' button and expecting the clipboard event to fire...")
require.NoError(t, page.First("#manual-copy-button").Click())
var authCode string
consoleLogPattern := regexp.MustCompile(`code (.+) to clipboard`)
testlib.RequireEventually(t, func(requireEventually *require.Assertions) {
logs, err := page.ReadNewLogs("browser")
requireEventually.NoError(err)
for _, log := range logs {
if match := consoleLogPattern.FindStringSubmatch(log.Message); match != nil {
authCode = match[1]
return
}
}
requireEventually.FailNow("expected console log was not found")
}, 3*time.Second, 100*time.Millisecond)
return authCode
}

View File

@ -479,6 +479,7 @@ func requireWellKnownEndpointIsWorking(t *testing.T, supervisorScheme, superviso
"jwks_uri": "%s/jwks.json", "jwks_uri": "%s/jwks.json",
"scopes_supported": ["openid", "offline"], "scopes_supported": ["openid", "offline"],
"response_types_supported": ["code"], "response_types_supported": ["code"],
"response_modes_supported": ["query", "form_post"],
"claims_supported": ["groups"], "claims_supported": ["groups"],
"discovery.supervisor.pinniped.dev/v1alpha1": {"pinniped_identity_providers_endpoint": "%s/v1alpha1/pinniped_identity_providers"}, "discovery.supervisor.pinniped.dev/v1alpha1": {"pinniped_identity_providers_endpoint": "%s/v1alpha1/pinniped_identity_providers"},
"subject_types_supported": ["public"], "subject_types_supported": ["public"],

View File

@ -65,6 +65,7 @@ func RequireEventuallyf(
msg string, msg string,
args ...interface{}, args ...interface{},
) { ) {
t.Helper()
RequireEventually(t, f, waitFor, tick, fmt.Sprintf(msg, args...)) RequireEventually(t, f, waitFor, tick, fmt.Sprintf(msg, args...))
} }

View File

@ -27,6 +27,10 @@ func Open(t *testing.T) *agouti.Page {
t.Logf("opening browser driver") t.Logf("opening browser driver")
env := testlib.IntegrationEnv(t) env := testlib.IntegrationEnv(t)
caps := agouti.NewCapabilities() caps := agouti.NewCapabilities()
// Capture console.log(), not just console.error().
caps["loggingPrefs"] = map[string]string{"browser": "INFO"}
if env.Proxy != "" { if env.Proxy != "" {
t.Logf("configuring Chrome to use proxy %q", env.Proxy) t.Logf("configuring Chrome to use proxy %q", env.Proxy)
caps = caps.Proxy(agouti.ProxyConfig{ caps = caps.Proxy(agouti.ProxyConfig{

View File

@ -311,11 +311,15 @@ func CreateTestFederationDomain(ctx context.Context, t *testing.T, issuer string
return federationDomain return federationDomain
} }
func RandHex(t *testing.T, numBytes int) string { func RandBytes(t *testing.T, numBytes int) []byte {
buf := make([]byte, numBytes) buf := make([]byte, numBytes)
_, err := io.ReadFull(rand.Reader, buf) _, err := io.ReadFull(rand.Reader, buf)
require.NoError(t, err) require.NoError(t, err)
return hex.EncodeToString(buf) return buf
}
func RandHex(t *testing.T, numBytes int) string {
return hex.EncodeToString(RandBytes(t, numBytes))
} }
func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType corev1.SecretType, stringData map[string]string) *corev1.Secret { func CreateTestSecret(t *testing.T, namespace string, baseName string, secretType corev1.SecretType, stringData map[string]string) *corev1.Secret {