diff --git a/cmd/pinniped/cmd/kubeconfig.go b/cmd/pinniped/cmd/kubeconfig.go index 32d724a2..54042b34 100644 --- a/cmd/pinniped/cmd/kubeconfig.go +++ b/cmd/pinniped/cmd/kubeconfig.go @@ -61,6 +61,7 @@ type getKubeconfigOIDCParams struct { listenPort uint16 scopes []string skipBrowser bool + skipListen bool sessionCachePath string debugSessionCache bool caBundle caBundleFlag @@ -146,6 +147,7 @@ func kubeconfigCommand(deps kubeconfigDeps) *cobra.Command { f.Uint16Var(&flags.oidc.listenPort, "oidc-listen-port", 0, "TCP port for localhost listener (authorization code flow only)") f.StringSliceVar(&flags.oidc.scopes, "oidc-scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OpenID Connect scopes to request during login") f.BoolVar(&flags.oidc.skipBrowser, "oidc-skip-browser", false, "During OpenID Connect login, skip opening the browser (just print the URL)") + f.BoolVar(&flags.oidc.skipListen, "oidc-skip-listen", false, "During OpenID Connect login, skip starting a localhost callback listener (manual copy/paste flow only)") f.StringVar(&flags.oidc.sessionCachePath, "oidc-session-cache", "", "Path to OpenID Connect session cache file") f.Var(&flags.oidc.caBundle, "oidc-ca-bundle", "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)") f.BoolVar(&flags.oidc.debugSessionCache, "oidc-debug-session-cache", false, "Print debug logs related to the OpenID Connect session cache") @@ -161,6 +163,9 @@ func kubeconfigCommand(deps kubeconfigDeps) *cobra.Command { f.StringVar(&flags.credentialCachePath, "credential-cache", "", "Path to cluster-specific credentials cache") mustMarkHidden(cmd, "oidc-debug-session-cache") + // --oidc-skip-listen is mainly needed for testing. We'll leave it hidden until we have a non-testing use case. + mustMarkHidden(cmd, "oidc-skip-listen") + mustMarkDeprecated(cmd, "concierge-namespace", "not needed anymore") mustMarkHidden(cmd, "concierge-namespace") @@ -317,6 +322,9 @@ func newExecConfig(deps kubeconfigDeps, flags getKubeconfigParams) (*clientcmdap if flags.oidc.skipBrowser { execConfig.Args = append(execConfig.Args, "--skip-browser") } + if flags.oidc.skipListen { + execConfig.Args = append(execConfig.Args, "--skip-listen") + } if flags.oidc.listenPort != 0 { execConfig.Args = append(execConfig.Args, "--listen-port="+strconv.Itoa(int(flags.oidc.listenPort))) } diff --git a/cmd/pinniped/cmd/kubeconfig_test.go b/cmd/pinniped/cmd/kubeconfig_test.go index 2853b0e4..cb9f7b83 100644 --- a/cmd/pinniped/cmd/kubeconfig_test.go +++ b/cmd/pinniped/cmd/kubeconfig_test.go @@ -1352,6 +1352,7 @@ func TestGetKubeconfig(t *testing.T) { "--concierge-ca-bundle", testConciergeCABundlePath, "--oidc-issuer", issuerURL, "--oidc-skip-browser", + "--oidc-skip-listen", "--oidc-listen-port", "1234", "--oidc-ca-bundle", f.Name(), "--oidc-session-cache", "/path/to/cache/dir/sessions.yaml", @@ -1405,6 +1406,7 @@ func TestGetKubeconfig(t *testing.T) { - --client-id=pinniped-cli - --scopes=offline_access,openid,pinniped:request-audience - --skip-browser + - --skip-listen - --listen-port=1234 - --ca-bundle-data=%s - --session-cache=/path/to/cache/dir/sessions.yaml diff --git a/cmd/pinniped/cmd/login_oidc.go b/cmd/pinniped/cmd/login_oidc.go index 83542c01..9141d26a 100644 --- a/cmd/pinniped/cmd/login_oidc.go +++ b/cmd/pinniped/cmd/login_oidc.go @@ -59,6 +59,7 @@ type oidcLoginFlags struct { listenPort uint16 scopes []string skipBrowser bool + skipListen bool sessionCachePath string caBundlePaths []string caBundleData []string @@ -91,6 +92,7 @@ func oidcLoginCommand(deps oidcLoginCommandDeps) *cobra.Command { cmd.Flags().Uint16Var(&flags.listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only)") cmd.Flags().StringSliceVar(&flags.scopes, "scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "pinniped:request-audience"}, "OIDC scopes to request during login") cmd.Flags().BoolVar(&flags.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL)") + cmd.Flags().BoolVar(&flags.skipListen, "skip-listen", false, "Skip starting a localhost callback listener (manual copy/paste flow only)") cmd.Flags().StringVar(&flags.sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file") cmd.Flags().StringSliceVar(&flags.caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated)") cmd.Flags().StringSliceVar(&flags.caBundleData, "ca-bundle-data", nil, "Base64 encoded TLS certificate authority bundle (base64 encoded PEM format, optional, can be repeated)") @@ -107,6 +109,8 @@ func oidcLoginCommand(deps oidcLoginCommandDeps) *cobra.Command { cmd.Flags().StringVar(&flags.upstreamIdentityProviderName, "upstream-identity-provider-name", "", "The name of the upstream identity provider used during login with a Supervisor") cmd.Flags().StringVar(&flags.upstreamIdentityProviderType, "upstream-identity-provider-type", "oidc", "The type of the upstream identity provider used during login with a Supervisor (e.g. 'oidc', 'ldap')") + // --skip-listen is mainly needed for testing. We'll leave it hidden until we have a non-testing use case. + mustMarkHidden(cmd, "skip-listen") mustMarkHidden(cmd, "debug-session-cache") mustMarkRequired(cmd, "issuer") cmd.RunE = func(cmd *cobra.Command, args []string) error { return runOIDCLogin(cmd, deps, flags) } @@ -182,12 +186,14 @@ func runOIDCLogin(cmd *cobra.Command, deps oidcLoginCommandDeps, flags oidcLogin } } - // --skip-browser replaces the default "browser open" function with one that prints to stderr. + // --skip-browser skips opening the browser. if flags.skipBrowser { - opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error { - cmd.PrintErr("Please log in: ", url, "\n") - return nil - })) + opts = append(opts, oidcclient.WithSkipBrowserOpen()) + } + + // --skip-listen skips starting the localhost callback listener. + if flags.skipListen { + opts = append(opts, oidcclient.WithSkipListen()) } if len(flags.caBundlePaths) > 0 || len(flags.caBundleData) > 0 { diff --git a/cmd/pinniped/cmd/login_oidc_test.go b/cmd/pinniped/cmd/login_oidc_test.go index b31d7021..055dcec6 100644 --- a/cmd/pinniped/cmd/login_oidc_test.go +++ b/cmd/pinniped/cmd/login_oidc_test.go @@ -226,6 +226,7 @@ func TestLoginOIDCCommand(t *testing.T) { "--client-id", "test-client-id", "--issuer", "test-issuer", "--skip-browser", + "--skip-listen", "--listen-port", "1234", "--debug-session-cache", "--request-audience", "cluster-1234", @@ -242,7 +243,7 @@ func TestLoginOIDCCommand(t *testing.T) { "--upstream-identity-provider-type", "ldap", }, env: map[string]string{"PINNIPED_DEBUG": "true"}, - wantOptionsCount: 10, + wantOptionsCount: 11, wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"token":"exchanged-token"}}` + "\n", wantLogs: []string{ "\"level\"=0 \"msg\"=\"Pinniped login: Performing OIDC login\" \"client id\"=\"test-client-id\" \"issuer\"=\"test-issuer\"", diff --git a/go.mod b/go.mod index 2b1522e8..3a224243 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.pinniped.dev -go 1.14 +go 1.16 require ( github.com/MakeNowJust/heredoc/v2 v2.0.1 @@ -26,6 +26,7 @@ require ( github.com/spf13/cobra v1.2.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 + github.com/tdewolff/minify/v2 v2.9.18 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a golang.org/x/net v0.0.0-20210520170846-37e1c6afe023 golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602 diff --git a/go.sum b/go.sum index a57507b9..eb6dfbd2 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,7 @@ github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cheekybits/is v0.0.0-20150225183255-68e9c0620927/go.mod h1:h/aW8ynjgkuj+NQRlZcDbAbM1ORAbXjXX77sX7T289U= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -840,6 +841,7 @@ github.com/markbates/safe v1.0.0/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kN github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/markbates/sigtx v1.0.0/go.mod h1:QF1Hv6Ic6Ca6W+T+DL0Y/ypborFKyvUY9HmuCD4VeTc= github.com/markbates/willie v1.0.9/go.mod h1:fsrFVWl91+gXpx/6dv715j7i11fYPfZ9ZGfH0DQzY7w= +github.com/matryer/try v0.0.0-20161228173917-9ac251b645a2/go.mod h1:0KeJpeMD6o+O4hW7qJOT7vyQPKrWmj26uf5wMc/IiIs= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= @@ -1150,6 +1152,12 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/subosito/gotenv v1.1.1/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= +github.com/tdewolff/minify/v2 v2.9.18 h1:j5Is0sOGp4cxm0o3HgvHCWCvTtmKnfB0qv0FCRbmgZY= +github.com/tdewolff/minify/v2 v2.9.18/go.mod h1:0y0mXZnisZm8HcgQvAV0btxa1IgecGam90zMuHqEZuc= +github.com/tdewolff/parse/v2 v2.5.18 h1:d67Ql/Pe36JcJZ7J2MY8upx6iTxbxGS9lzwyFGtMmd0= +github.com/tdewolff/parse/v2 v2.5.18/go.mod h1:WzaJpRSbwq++EIQHYIRTpbYKNA3gn9it1Ik++q4zyho= +github.com/tdewolff/test v1.0.6 h1:76mzYJQ83Op284kMT+63iCNCI7NEERsIN8dLM+RiKr4= +github.com/tdewolff/test v1.0.6/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE= github.com/tidwall/gjson v1.3.2/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/gjson v1.6.8/go.mod h1:zeFuBCIqD4sN/gmqBzZ4j7Jd6UcA2Fc56x7QFsv+8fI= github.com/tidwall/gjson v1.7.1/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= diff --git a/internal/httputil/securityheader/securityheader.go b/internal/httputil/securityheader/securityheader.go index 2bb3af12..e95edd0b 100644 --- a/internal/httputil/securityheader/securityheader.go +++ b/internal/httputil/securityheader/securityheader.go @@ -1,16 +1,22 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 // Package securityheader implements an HTTP middleware for setting security-related response headers. package securityheader -import "net/http" +import ( + "net/http" +) // Wrap the provided http.Handler so it sets appropriate security-related response headers. func Wrap(wrapped http.Handler) http.Handler { + return WrapWithCustomCSP(wrapped, "default-src 'none'; frame-ancestors 'none'") +} + +func WrapWithCustomCSP(wrapped http.Handler, cspHeader string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.Header() - h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") + h.Set("Content-Security-Policy", cspHeader) h.Set("X-Frame-Options", "DENY") h.Set("X-XSS-Protection", "1; mode=block") h.Set("X-Content-Type-Options", "nosniff") diff --git a/internal/httputil/securityheader/securityheader_test.go b/internal/httputil/securityheader/securityheader_test.go index 7ee7331f..639c495c 100644 --- a/internal/httputil/securityheader/securityheader_test.go +++ b/internal/httputil/securityheader/securityheader_test.go @@ -16,40 +16,71 @@ import ( ) func TestWrap(t *testing.T) { - testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Test-Header", "test value") - _, _ = w.Write([]byte("hello world")) - }))) - t.Cleanup(testServer.Close) + for _, tt := range []struct { + name string + wrapFunc func(http.Handler) http.Handler + expectHeaders http.Header + }{ + { + name: "wrap", + wrapFunc: Wrap, + expectHeaders: http.Header{ + "X-Test-Header": []string{"test value"}, + "Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Referrer-Policy": []string{"no-referrer"}, + "X-Content-Type-Options": []string{"nosniff"}, + "X-Frame-Options": []string{"DENY"}, + "X-Xss-Protection": []string{"1; mode=block"}, + "X-Dns-Prefetch-Control": []string{"off"}, + "Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"}, + "Pragma": []string{"no-cache"}, + "Expires": []string{"0"}, + }, + }, + { + name: "custom CSP", + wrapFunc: func(h http.Handler) http.Handler { return WrapWithCustomCSP(h, "my-custom-csp-header") }, + expectHeaders: http.Header{ + "X-Test-Header": []string{"test value"}, + "Content-Security-Policy": []string{"my-custom-csp-header"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + "Referrer-Policy": []string{"no-referrer"}, + "X-Content-Type-Options": []string{"nosniff"}, + "X-Frame-Options": []string{"DENY"}, + "X-Xss-Protection": []string{"1; mode=block"}, + "X-Dns-Prefetch-Control": []string{"off"}, + "Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"}, + "Pragma": []string{"no-cache"}, + "Expires": []string{"0"}, + }, + }, + } { + tt := tt + t.Run(tt.name, func(t *testing.T) { + testServer := httptest.NewServer(tt.wrapFunc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test-Header", "test value") + _, _ = w.Write([]byte("hello world")) + }))) + t.Cleanup(testServer.Close) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) - require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - require.Equal(t, http.StatusOK, resp.StatusCode) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) - respBody, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - require.Equal(t, "hello world", string(respBody)) + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "hello world", string(respBody)) - expected := http.Header{ - "X-Test-Header": []string{"test value"}, - "Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"}, - "Content-Type": []string{"text/plain; charset=utf-8"}, - "Referrer-Policy": []string{"no-referrer"}, - "X-Content-Type-Options": []string{"nosniff"}, - "X-Frame-Options": []string{"DENY"}, - "X-Xss-Protection": []string{"1; mode=block"}, - "X-Dns-Prefetch-Control": []string{"off"}, - "Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"}, - "Pragma": []string{"no-cache"}, - "Expires": []string{"0"}, - } - for key, values := range expected { - assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key) + for key, values := range tt.expectHeaders { + assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key) + } + }) } } diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 9c92301e..e6917954 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -1156,7 +1156,7 @@ func TestAuthorizationEndpoint(t *testing.T) { require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) case test.wantRedirectLocationRegexp != "": require.Len(t, rsp.Header().Values("Location"), 1) - oidctestutil.RequireAuthcodeRedirectLocation( + oidctestutil.RequireAuthCodeRegexpMatch( t, rsp.Header().Get("Location"), test.wantRedirectLocationRegexp, diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index d585c962..8b9ab93e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -18,6 +18,7 @@ import ( "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/internal/plog" ) @@ -35,7 +36,7 @@ func NewHandler( stateDecoder, cookieDecoder oidc.Decoder, redirectURI string, ) http.Handler { - return securityheader.Wrap(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + handler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) if err != nil { return err @@ -97,7 +98,8 @@ func NewHandler( oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) return nil - })) + }) + return securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy()) } func authcode(r *http.Request) string { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 583ee943..23912944 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -122,6 +122,7 @@ func TestCallbackEndpoint(t *testing.T) { wantContentType string wantBody string wantRedirectLocationRegexp string + wantBodyFormResponseRegexp string wantDownstreamGrantedScopes []string wantDownstreamIDTokenSubject string wantDownstreamIDTokenUsername string @@ -133,6 +134,32 @@ func TestCallbackEndpoint(t *testing.T) { wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs }{ + { + name: "GET with good state and cookie and successful upstream token exchange with response_mode=form_post returns 200 with HTML+JS form", + idp: happyUpstream().Build(), + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithAuthorizeRequestParams( + shallowCopyAndModifyQuery( + happyDownstreamRequestParamsQuery, + map[string]string{"response_mode": "form_post"}, + ).Encode(), + ).Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: "text/html;charset=UTF-8", + wantBodyFormResponseRegexp: `(.+)`, + wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", idp: happyUpstream().Build(), @@ -666,15 +693,40 @@ func TestCallbackEndpoint(t *testing.T) { require.Equal(t, test.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) - if test.wantBody != "" { + switch { + // If we want a specific static response body, assert that. + case test.wantBody != "": require.Equal(t, test.wantBody, rsp.Body.String()) - } else { + + // Else if we want a body that contains a regex-matched auth code, assert that (for "response_mode=form_post"). + case test.wantBodyFormResponseRegexp != "": + oidctestutil.RequireAuthCodeRegexpMatch( + t, + rsp.Body.String(), + test.wantBodyFormResponseRegexp, + client, + secrets, + oauthStore, + test.wantDownstreamGrantedScopes, + test.wantDownstreamIDTokenSubject, + test.wantDownstreamIDTokenUsername, + test.wantDownstreamIDTokenGroups, + test.wantDownstreamRequestedScopes, + test.wantDownstreamPKCEChallenge, + test.wantDownstreamPKCEChallengeMethod, + test.wantDownstreamNonce, + downstreamClientID, + downstreamRedirectURI, + ) + + // Otherwise, expect an empty response body. + default: require.Empty(t, rsp.Body.String()) } if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test require.Len(t, rsp.Header().Values("Location"), 1) - oidctestutil.RequireAuthcodeRedirectLocation( + oidctestutil.RequireAuthCodeRegexpMatch( t, rsp.Header().Get("Location"), test.wantRedirectLocationRegexp, diff --git a/internal/oidc/clientregistry/clientregistry.go b/internal/oidc/clientregistry/clientregistry.go index f60cc07d..c01caa7d 100644 --- a/internal/oidc/clientregistry/clientregistry.go +++ b/internal/oidc/clientregistry/clientregistry.go @@ -18,10 +18,16 @@ type Client struct { fosite.DefaultOpenIDConnectClient } -// It implements both the base and OIDC client interfaces of Fosite. +func (c Client) GetResponseModes() []fosite.ResponseModeType { + // For now, all Pinniped clients always support "" (unspecified), "query", and "form_post" response modes. + return []fosite.ResponseModeType{fosite.ResponseModeDefault, fosite.ResponseModeQuery, fosite.ResponseModeFormPost} +} + +// It implements both the base, OIDC, and response_mode client interfaces of Fosite. var ( _ fosite.Client = (*Client)(nil) _ fosite.OpenIDConnectClient = (*Client)(nil) + _ fosite.ResponseModeClient = (*Client)(nil) ) // StaticClientManager is a fosite.ClientManager with statically-defined clients. diff --git a/internal/oidc/clientregistry/clientregistry_test.go b/internal/oidc/clientregistry/clientregistry_test.go index 0da67fcc..5062f629 100644 --- a/internal/oidc/clientregistry/clientregistry_test.go +++ b/internal/oidc/clientregistry/clientregistry_test.go @@ -59,6 +59,7 @@ func TestPinnipedCLI(t *testing.T) { require.Equal(t, "", c.GetRequestObjectSigningAlgorithm()) require.Equal(t, "none", c.GetTokenEndpointAuthMethod()) require.Equal(t, "RS256", c.GetTokenEndpointAuthSigningAlgorithm()) + require.Equal(t, []fosite.ResponseModeType{"", "query", "form_post"}, c.GetResponseModes()) marshaled, err := json.Marshal(c) require.NoError(t, err) diff --git a/internal/oidc/discovery/discovery_handler.go b/internal/oidc/discovery/discovery_handler.go index e472c012..008808b6 100644 --- a/internal/oidc/discovery/discovery_handler.go +++ b/internal/oidc/discovery/discovery_handler.go @@ -25,6 +25,7 @@ type Metadata struct { JWKSURI string `json:"jwks_uri"` ResponseTypesSupported []string `json:"response_types_supported"` + ResponseModesSupported []string `json:"response_modes_supported"` SubjectTypesSupported []string `json:"subject_types_supported"` IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` @@ -63,6 +64,7 @@ func NewHandler(issuerURL string) http.Handler { JWKSURI: issuerURL + oidc.JWKSEndpointPath, SupervisorDiscovery: SupervisorDiscoveryMetadataV1Alpha1{PinnipedIDPsEndpoint: issuerURL + oidc.PinnipedIDPsPathV1Alpha1}, ResponseTypesSupported: []string{"code"}, + ResponseModesSupported: []string{"query", "form_post"}, SubjectTypesSupported: []string{"public"}, IDTokenSigningAlgValuesSupported: []string{"ES256"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, diff --git a/internal/oidc/discovery/discovery_handler_test.go b/internal/oidc/discovery/discovery_handler_test.go index b3c70b35..b1707f77 100644 --- a/internal/oidc/discovery/discovery_handler_test.go +++ b/internal/oidc/discovery/discovery_handler_test.go @@ -43,6 +43,7 @@ func TestDiscovery(t *testing.T) { PinnipedIDPsEndpoint: "https://some-issuer.com/some/path/v1alpha1/pinniped_identity_providers", }, ResponseTypesSupported: []string{"code"}, + ResponseModesSupported: []string{"query", "form_post"}, SubjectTypesSupported: []string{"public"}, IDTokenSigningAlgValuesSupported: []string{"ES256"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index d92a0f1b..d29979c8 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -14,6 +14,7 @@ import ( "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/pkce" ) @@ -217,7 +218,7 @@ func FositeOauth2Helper( MinParameterEntropy: fosite.MinParameterEntropy, } - return compose.Compose( + provider := compose.Compose( oauthConfig, oauthStore, &compose.CommonStrategy{ @@ -233,6 +234,8 @@ func FositeOauth2Helper( compose.OAuth2PKCEFactory, TokenExchangeFactory, ) + provider.(*fosite.Fosite).FormPostHTMLTemplate = formposthtml.Template() + return provider } // FositeErrorForLog generates a list of information about the provided Fosite error that can be diff --git a/internal/oidc/provider/formposthtml/form_post.css b/internal/oidc/provider/formposthtml/form_post.css new file mode 100644 index 00000000..c65c2fc7 --- /dev/null +++ b/internal/oidc/provider/formposthtml/form_post.css @@ -0,0 +1,87 @@ +/* Copyright 2021 the Pinniped contributors. All Rights Reserved. */ +/* SPDX-License-Identifier: Apache-2.0 */ + +body { + font-family: "Metropolis-Light", Helvetica, sans-serif; +} + +h1 { + font-size: 20px; +} + +.state { + position: absolute; + top: 100px; + left: 50%; + width: 400px; + height: 80px; + margin-top: -40px; + margin-left: -200px; + font-size: 14px; + line-height: 24px; +} + +button { + margin: -10px; + padding: 10px; + text-align: left; + width: 100%; + display: inline; + border: none; + background: none; + cursor: pointer; + transition: all .1s; +} + +button:hover { + background-color: #eee; + transform: scale(1.01); +} + +button:active { + background-color: #ddd; + transform: scale(.99); +} + +code { + word-wrap: break-word; + hyphens: auto; + hyphenate-character: ''; + font-size: 12px; + font-family: monospace; + color: #333; +} + +.copy-icon { + float: left; + width: 36px; + height: 36px; + padding-top: 2px; + padding-right: 10px; + background-size: contain; + background-repeat: no-repeat; + /* + This is the "copy-to-clipboard-line.svg" icon from Clarity (https://clarity.design/): + https://github.com/vmware/clarity-assets/blob/master/icons/essential/copy-to-clipboard-line.svg + */ + background-image: url("data:image/svg+xml,%3Csvg version='1.1' width='36' height='36' viewBox='0 0 36 36' preserveAspectRatio='xMidYMid meet' xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink'%3E%3Ctitle%3Ecopy-to-clipboard-line%3C/title%3E%3Cpath d='M22.6,4H21.55a3.89,3.89,0,0,0-7.31,0H13.4A2.41,2.41,0,0,0,11,6.4V10H25V6.4A2.41,2.41,0,0,0,22.6,4ZM23,8H13V6.25A.25.25,0,0,1,13.25,6h2.69l.12-1.11A1.24,1.24,0,0,1,16.61,4a2,2,0,0,1,3.15,1.18l.09.84h2.9a.25.25,0,0,1,.25.25Z' class='clr-i-outline clr-i-outline-path-1'%3E%3C/path%3E%3Cpath d='M33.25,18.06H21.33l2.84-2.83a1,1,0,1,0-1.42-1.42L17.5,19.06l5.25,5.25a1,1,0,0,0,.71.29,1,1,0,0,0,.71-1.7l-2.84-2.84H33.25a1,1,0,0,0,0-2Z' class='clr-i-outline clr-i-outline-path-2'%3E%3C/path%3E%3Cpath d='M29,16h2V6.68A1.66,1.66,0,0,0,29.35,5H27.08V7H29Z' class='clr-i-outline clr-i-outline-path-3'%3E%3C/path%3E%3Cpath d='M29,31H7V7H9V5H6.64A1.66,1.66,0,0,0,5,6.67V31.32A1.66,1.66,0,0,0,6.65,33H29.36A1.66,1.66,0,0,0,31,31.33V22.06H29Z' class='clr-i-outline clr-i-outline-path-4'%3E%3C/path%3E%3Crect x='0' y='0' width='36' height='36' fill-opacity='0'/%3E%3C/svg%3E"); +} + +@keyframes loader { + to { + transform: rotate(360deg); + } +} + +#loading { + content: ''; + box-sizing: border-box; + width: 80px; + height: 80px; + margin-top: -40px; + margin-left: -40px; + border-radius: 50%; + border: 2px solid #fff; + border-top-color: #1b3951; + animation: loader .6s linear infinite; +} diff --git a/internal/oidc/provider/formposthtml/form_post.gohtml b/internal/oidc/provider/formposthtml/form_post.gohtml new file mode 100644 index 00000000..92be18d2 --- /dev/null +++ b/internal/oidc/provider/formposthtml/form_post.gohtml @@ -0,0 +1,34 @@ + + + + + + + + + + +
+ + +
+ + + + + diff --git a/internal/oidc/provider/formposthtml/form_post.js b/internal/oidc/provider/formposthtml/form_post.js new file mode 100644 index 00000000..4c0eb7df --- /dev/null +++ b/internal/oidc/provider/formposthtml/form_post.js @@ -0,0 +1,54 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +window.onload = () => { + const transitionToState = (id) => { + // Hide all the other ".state"
s. + Array.from(document.querySelectorAll('.state')).forEach(e => e.hidden = true); + + // Unhide the current state
. + const currentDiv = document.getElementById(id) + currentDiv.hidden = false; + + // Set the window title. + document.title = currentDiv.dataset.title; + + // Set the favicon using inline SVG (does not work on Safari). + document.getElementById('favicon').setAttribute( + 'href', + 'data:image/svg+xml,' + + currentDiv.dataset.favicon + + '' + ); + } + + // At load, show the spinner, hide the other divs, set the favicon, and + // replace the URL path with './' so the upstream auth code disappears. + transitionToState('loading'); + window.history.replaceState(null, '', './'); + + // When the copy button is clicked, copy to the clipboard. + document.getElementById('manual-copy-button').onclick = () => { + const code = document.getElementById('manual-copy-button').innerText; + navigator.clipboard.writeText(code) + .then(() => console.info('copied authorization code ' + code + ' to clipboard')) + .catch(e => console.error('failed to copy code ' + code + ' to clipboard: ' + e)); + }; + + // Set a timeout to transition to the "manual" state if nothing succeeds within 2s. + const timeout = setTimeout(() => transitionToState('manual'), 2000); + + // Try to submit the POST callback, handling the success and error cases. + const responseParams = document.forms[0].elements; + fetch( + responseParams['redirect_uri'].value, + { + method: 'POST', + mode: 'no-cors', + headers: {'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8'}, + body: responseParams['encoded_params'].value, + }) + .then(() => clearTimeout(timeout)) + .then(() => transitionToState('success')) + .catch(() => transitionToState('manual')); +}; diff --git a/internal/oidc/provider/formposthtml/formposthtml.go b/internal/oidc/provider/formposthtml/formposthtml.go new file mode 100644 index 00000000..4eeebf74 --- /dev/null +++ b/internal/oidc/provider/formposthtml/formposthtml.go @@ -0,0 +1,65 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Package formposthtml defines HTML templates used by the Supervisor. +//nolint: gochecknoglobals // This package uses globals to ensure that all parsing and minifying happens at init. +package formposthtml + +import ( + "crypto/sha256" + _ "embed" // Needed to trigger //go:embed directives below. + "encoding/base64" + "html/template" + "strings" + + "github.com/tdewolff/minify/v2/minify" +) + +var ( + //go:embed form_post.css + rawCSS string + minifiedCSS = mustMinify(minify.CSS(rawCSS)) + + //go:embed form_post.js + rawJS string + minifiedJS = mustMinify(minify.JS(rawJS)) + + //go:embed form_post.gohtml + rawHTMLTemplate string +) + +// Parse the Go templated HTML and inject functions providing the minified inline CSS and JS. +var parsedHTMLTemplate = template.Must(template.New("form_post.gohtml").Funcs(template.FuncMap{ + "minifiedCSS": func() template.CSS { return template.CSS(minifiedCSS) }, + "minifiedJS": func() template.JS { return template.JS(minifiedJS) }, //nolint:gosec // This is 100% static input, not attacker-controlled. +}).Parse(rawHTMLTemplate)) + +// Generate the CSP header value once since it's effectively constant: +var cspValue = strings.Join([]string{ + `default-src 'none'`, + `script-src '` + cspHash(minifiedJS) + `'`, + `style-src '` + cspHash(minifiedCSS) + `'`, + `img-src data:`, + `connect-src *`, + `frame-ancestors 'none'`, +}, "; ") + +func mustMinify(s string, err error) string { + if err != nil { + panic(err) + } + return s +} + +func cspHash(s string) string { + hashBytes := sha256.Sum256([]byte(s)) + return "sha256-" + base64.StdEncoding.EncodeToString(hashBytes[:]) +} + +// ContentSecurityPolicy returns the Content-Security-Policy header value to make the Template() operate correctly. +// +// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/default-src#:~:text=%27%3Chash-algorithm%3E-%3Cbase64-value%3E%27. +func ContentSecurityPolicy() string { return cspValue } + +// Template returns the html/template.Template for rendering the response_type=form_post response page. +func Template() *template.Template { return parsedHTMLTemplate } diff --git a/internal/oidc/provider/formposthtml/formposthtml_test.go b/internal/oidc/provider/formposthtml/formposthtml_test.go new file mode 100644 index 00000000..a8a1a929 --- /dev/null +++ b/internal/oidc/provider/formposthtml/formposthtml_test.go @@ -0,0 +1,101 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package formposthtml + +import ( + "bytes" + "fmt" + "net/url" + "testing" + + "github.com/ory/fosite" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/here" +) + +var ( + testRedirectURL = "http://127.0.0.1:12345/callback" + + testResponseParams = url.Values{ + "code": []string{"test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU"}, + "scope": []string{"openid offline_access pinniped:request-audience"}, + "state": []string{"01234567890123456789012345678901"}, + } + + testExpectedFormPostOutput = here.Doc(` + + + + + + + + + + +
+ + +
+ + + + + + `) + + // It's okay if this changes in the future, but this gives us a chance to eyeball the formatting. + // Our browser-based integration tests should find any incompatibilities. + testExpectedCSP = `default-src 'none'; ` + + `script-src 'sha256-U+tKnJ2oMSYKSxmSX3V2mPBN8xdr9JpampKAhbSo108='; ` + + `style-src 'sha256-TLAQE3UR2KpwP7AzMCE4iPDizh7zLPx9UXeK5ntuoRg='; ` + + `img-src data:; ` + + `connect-src *; ` + + `frame-ancestors 'none'` +) + +func TestTemplate(t *testing.T) { + // Use the Fosite helper to render the form, ensuring that the parameters all have the same names + types. + var buf bytes.Buffer + fosite.WriteAuthorizeFormPostResponse(testRedirectURL, testResponseParams, Template(), &buf) + + // Render again so we can confirm that there is no error returned (Fosite ignores any error). + var buf2 bytes.Buffer + require.NoError(t, Template().Execute(&buf2, struct { + RedirURL string + Parameters url.Values + }{ + RedirURL: testRedirectURL, + Parameters: testResponseParams, + })) + + require.Equal(t, buf.String(), buf2.String()) + require.Equal(t, testExpectedFormPostOutput, buf.String()) +} + +func TestContentSecurityPolicyHashes(t *testing.T) { + require.Equal(t, testExpectedCSP, ContentSecurityPolicy()) +} + +func TestHelpers(t *testing.T) { + // These are silly tests but it's easy to we might as well have them. + require.Equal(t, "test", mustMinify("test", nil)) + require.PanicsWithError(t, "some error", func() { mustMinify("", fmt.Errorf("some error")) }) + + // Example test vector from https://content-security-policy.com/hash/. + require.Equal(t, "sha256-RFWPLDbv2BY+rCkDzsE+0fr8ylGr2R2faWMhq4lfEQc=", cspHash("doSomething();")) +} diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go index 54fc8563..9286bff1 100644 --- a/internal/testutil/assertions.go +++ b/internal/testutil/assertions.go @@ -1,4 +1,4 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testutil @@ -55,7 +55,9 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret } func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { - require.Equal(t, "default-src 'none'; frame-ancestors 'none'", response.Header().Get("Content-Security-Policy")) + // This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use. + require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'") + require.Equal(t, "DENY", response.Header().Get("X-Frame-Options")) require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection")) require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options")) diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index b8e7b0de..f690af52 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -235,10 +235,10 @@ func VerifyECDSAIDToken( return token } -func RequireAuthcodeRedirectLocation( +func RequireAuthCodeRegexpMatch( t *testing.T, - actualRedirectLocation string, - wantRedirectLocationRegexp string, + actualContent string, + wantRegexp string, kubeClient *fake.Clientset, secretsClient v1.SecretInterface, oauthStore fositestoragei.AllFositeStorage, @@ -256,9 +256,9 @@ func RequireAuthcodeRedirectLocation( t.Helper() // Assert that Location header matches regular expression. - regex := regexp.MustCompile(wantRedirectLocationRegexp) - submatches := regex.FindStringSubmatch(actualRedirectLocation) - require.Lenf(t, submatches, 2, "no regexp match in actualRedirectLocation: %q", actualRedirectLocation) + regex := regexp.MustCompile(wantRegexp) + submatches := regex.FindStringSubmatch(actualContent) + require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent) capturedAuthCode := submatches[1] // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index fdf34eff..36b688e2 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "mime" "net" "net/http" @@ -17,6 +18,7 @@ import ( "os" "sort" "strings" + "syscall" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -87,6 +89,7 @@ type handlerState struct { // Generated parameters of a login flow. provider *oidc.Provider oauth2Config *oauth2.Config + useFormPost bool state state.State nonce nonce.Nonce pkce pkce.Code @@ -96,10 +99,12 @@ type handlerState struct { generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error + listen func(string, string) (net.Listener, error) + isTTY func(int) bool getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) - promptForValue func(promptLabel string) (string, error) - promptForSecret func(promptLabel string) (string, error) + promptForValue func(ctx context.Context, promptLabel string) (string, error) + promptForSecret func(ctx context.Context, promptLabel string) (string, error) callbacks chan callbackResult } @@ -156,6 +161,9 @@ func WithScopes(scopes []string) Option { // WithBrowserOpen overrides the default "open browser" functionality with a custom callback. If not specified, // an implementation using https://github.com/pkg/browser will be used by default. +// +// Deprecated: this option will be removed in a future version of Pinniped. See the +// WithSkipBrowserOpen() option instead. func WithBrowserOpen(openURL func(url string) error) Option { return func(h *handlerState) error { h.openURL = openURL @@ -163,6 +171,23 @@ func WithBrowserOpen(openURL func(url string) error) Option { } } +// WithSkipBrowserOpen causes the login to only print the authorize URL, but skips attempting to +// open the user's default web browser. +func WithSkipBrowserOpen() Option { + return func(h *handlerState) error { + h.openURL = func(_ string) error { return nil } + return nil + } +} + +// WithSkipListen causes the login skip starting the localhost listener, forcing the manual copy/paste login flow. +func WithSkipListen() Option { + return func(h *handlerState) error { + h.listen = func(string, string) (net.Listener, error) { return nil, nil } + return nil + } +} + // SessionCacheKey contains the data used to select a valid session cache entry. type SessionCacheKey struct { Issuer string `json:"issuer"` @@ -250,6 +275,8 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, + listen: net.Listen, + isTTY: term.IsTerminal, getProvider: upstreamoidc.New, validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) { return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token) @@ -376,11 +403,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) { // and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error. func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { // Ask the user for their username and password. - username, err := h.promptForValue(defaultLDAPUsernamePrompt) + username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt) if err != nil { return nil, fmt.Errorf("error prompting for username: %w", err) } - password, err := h.promptForSecret(defaultLDAPPasswordPrompt) + password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt) if err != nil { return nil, fmt.Errorf("error prompting for password: %w", err) } @@ -475,30 +502,55 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) ( // Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint. // Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error. func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { - // Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number). - listener, err := net.Listen("tcp", h.listenAddr) + // Attempt to open a local TCP listener, logging but otherwise ignoring any error. + listener, err := h.listen("tcp", h.listenAddr) if err != nil { - return nil, fmt.Errorf("could not open callback listener: %w", err) + h.logger.V(debugLogLevel).Error(err, "could not open callback listener") + } + + // If the listener failed to start and stdin is not a TTY, then we have no hope of succeeding, + // since we won't be able to receive the web callback and we can't prompt for the manual auth code. + if listener == nil && !h.isTTY(syscall.Stdin) { + return nil, fmt.Errorf("login failed: must have either a localhost listener or stdin must be a TTY") + } + + // Update the OAuth2 redirect_uri to match the actual listener address (if there is one), or just use + // a fake ":0" port if there is no listener running. + redirectURI := url.URL{Scheme: "http", Path: h.callbackPath} + if listener == nil { + redirectURI.Host = "127.0.0.1:0" + } else { + redirectURI.Host = listener.Addr().String() + } + h.oauth2Config.RedirectURL = redirectURI.String() + + // If the server supports it, request response_mode=form_post. + authParams := *authorizeOptions + if h.useFormPost { + authParams = append(authParams, oauth2.SetAuthURLParam("response_mode", "form_post")) } - h.oauth2Config.RedirectURL = (&url.URL{ - Scheme: "http", - Host: listener.Addr().String(), - Path: h.callbackPath, - }).String() // Now that we have a redirect URL with the listener port, we can build the authorize URL. - authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), *authorizeOptions...) + authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...) - // Start a callback server in a background goroutine. - shutdown := h.serve(listener) - defer shutdown() - - // Open the authorize URL in the users browser. - if err := h.openURL(authorizeURL); err != nil { - return nil, fmt.Errorf("could not open browser: %w", err) + // If there is a listener running, start serving the callback handler in a background goroutine. + if listener != nil { + shutdown := h.serve(listener) + defer shutdown() } - // Wait for either the callback or a timeout. + // Open the authorize URL in the users browser, logging but otherwise ignoring any error. + if err := h.openURL(authorizeURL); err != nil { + h.logger.V(debugLogLevel).Error(err, "could not open browser") + } + + ctx, cancel := context.WithCancel(h.ctx) + defer cancel() + + // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). + h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + + // Wait for either the web callback, a pasted auth code, or a timeout. select { case <-h.ctx.Done(): return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err()) @@ -510,7 +562,37 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } -func promptForValue(promptLabel string) (string, error) { +func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) { + _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) + + // If stdin is not a TTY, print the URL but don't prompt for the manual paste, + // since we have no way of reading it. + if !h.isTTY(syscall.Stdin) { + return + } + + // If the server didn't support response_mode=form_post, don't bother prompting for the manual + // code because the user isn't going to have any easy way to manually copy it anyway. + if !h.useFormPost { + return + } + + // Launch the manual auth code prompt in a background goroutine, which will be cancelled + // if the parent context is cancelled (when the login succeeds or times out). + go func() { + code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ") + if err != nil { + h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} + return + } + + // When a code is pasted, redeem it for a token and return that result on the callbacks channel. + token, err := h.redeemAuthCode(ctx, code) + h.callbacks <- callbackResult{token: token, err: err} + }() +} + +func promptForValue(ctx context.Context, promptLabel string) (string, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return "", errors.New("stdin is not connected to a terminal") } @@ -518,6 +600,15 @@ func promptForValue(promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } + + // If the context is canceled, set the read deadline on stdin so the read immediately finishes. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + _ = os.Stdin.SetReadDeadline(time.Now()) + }() + text, err := bufio.NewReader(os.Stdin).ReadString('\n') if err != nil { return "", fmt.Errorf("could read input from stdin: %w", err) @@ -526,7 +617,7 @@ func promptForValue(promptLabel string) (string, error) { return text, nil } -func promptForSecret(promptLabel string) (string, error) { +func promptForSecret(ctx context.Context, promptLabel string) (string, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return "", errors.New("stdin is not connected to a terminal") } @@ -534,17 +625,27 @@ func promptForSecret(promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - password, err := term.ReadPassword(0) + + // If the context is canceled, set the read deadline on stdin so the read immediately finishes. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + _ = os.Stdin.SetReadDeadline(time.Now()) + + // term.ReadPassword swallows the newline that was typed by the user, so to + // avoid the next line of output from happening on same line as the password + // prompt, we need to print a newline. + // + // Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next + // on stderr is formatted correctly. + _, _ = fmt.Fprint(os.Stderr, "\n") + }() + + password, err := term.ReadPassword(syscall.Stdin) if err != nil { return "", fmt.Errorf("could not read password: %w", err) } - // term.ReadPassword swallows the newline that was typed by the user, so to - // avoid the next line of output from happening on same line as the password - // prompt, we need to print a newline. - _, err = fmt.Fprint(os.Stderr, "\n") - if err != nil { - return "", fmt.Errorf("could not print newline to stderr: %w", err) - } return string(password), err } @@ -567,9 +668,27 @@ func (h *handlerState) initOIDCDiscovery() error { Endpoint: h.provider.Endpoint(), Scopes: h.scopes, } + + // Use response_mode=form_post if the provider supports it. + var discoveryClaims struct { + ResponseModesSupported []string `json:"response_modes_supported"` + } + if err := h.provider.Claims(&discoveryClaims); err != nil { + return fmt.Errorf("could not decode response_modes_supported in OIDC discovery from %q: %w", h.issuer, err) + } + h.useFormPost = stringSliceContains(discoveryClaims.ResponseModesSupported, "form_post") return nil } +func stringSliceContains(slice []string, s string) bool { + for _, item := range slice { + if item == s { + return true + } + } + return false +} + func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) { h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience) // Perform OIDC discovery. This may have already been performed if there was not a cached base token. @@ -664,13 +783,29 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } }() - // Return HTTP 405 for anything that's not a GET. - if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") + var params url.Values + if h.useFormPost { + // Return HTTP 405 for anything that's not a POST. + if r.Method != http.MethodPost { + return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST") + } + + // Parse and pull the response parameters from a application/x-www-form-urlencoded request body. + if err := r.ParseForm(); err != nil { + return httperr.Wrap(http.StatusBadRequest, "invalid form", err) + } + params = r.Form + } else { + // Return HTTP 405 for anything that's not a GET. + if r.Method != http.MethodGet { + return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") + } + + // Pull response parameters from the URL query string. + params = r.URL.Query() } // Validate OAuth2 state and fail if it's incorrect (to block CSRF). - params := r.URL.Query() if err := h.state.Validate(params.Get("state")); err != nil { return httperr.New(http.StatusForbidden, "missing or invalid state parameter") } @@ -685,14 +820,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). - ExchangeAuthcodeAndValidateTokens( - r.Context(), - params.Get("code"), - h.pkce, - h.nonce, - h.oauth2Config.RedirectURL, - ) + token, err := h.redeemAuthCode(r.Context(), params.Get("code")) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } @@ -702,6 +830,17 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return nil } +func (h *handlerState) redeemAuthCode(ctx context.Context, code string) (*oidctypes.Token, error) { + return h.getProvider(h.oauth2Config, h.provider, h.httpClient). + ExchangeAuthcodeAndValidateTokens( + ctx, + code, + h.pkce, + h.nonce, + h.oauth2Config.RedirectURL, + ) +} + func (h *handlerState) serve(listener net.Listener) func() { mux := http.NewServeMux() mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index d85c01ac..358becd3 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -4,15 +4,18 @@ package oidcclient import ( + "bytes" "context" "encoding/json" "errors" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" "strings" + "syscall" "testing" "time" @@ -80,6 +83,22 @@ func TestLogin(t *testing.T) { // nolint:gocyclo })) t.Cleanup(errorServer.Close) + // Start a test server that returns discovery data with a broken response_modes_supported value. + brokenResponseModeMux := http.NewServeMux() + brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux) + brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + type providerJSON struct { + Issuer string `json:"issuer"` + ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string). + } + _ = json.NewEncoder(w).Encode(&providerJSON{ + Issuer: brokenResponseModeServer.URL, + ResponseModesSupported: "invalid", + }) + }) + t.Cleanup(brokenResponseModeServer.Close) + // Start a test server that returns discovery data with a broken token URL brokenTokenURLMux := http.NewServeMux() brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux) @@ -100,30 +119,29 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) t.Cleanup(brokenTokenURLServer.Close) - // Start a test server that returns a real discovery document and answers refresh requests. - providerMux := http.NewServeMux() - successServer := httptest.NewServer(providerMux) - t.Cleanup(successServer.Close) - providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "unexpected method", http.StatusMethodNotAllowed) - return + discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + return + } + w.Header().Set("content-type", "application/json") + _ = json.NewEncoder(w).Encode(&struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + }{ + Issuer: server.URL, + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + JWKSURL: server.URL + "/keys", + ResponseModesSupported: responseModes, + }) } - w.Header().Set("content-type", "application/json") - type providerJSON struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKSURL string `json:"jwks_uri"` - } - _ = json.NewEncoder(w).Encode(&providerJSON{ - Issuer: successServer.URL, - AuthURL: successServer.URL + "/authorize", - TokenURL: successServer.URL + "/token", - JWKSURL: successServer.URL + "/keys", - }) - }) - providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + } + tokenHandler := func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "unexpected method", http.StatusMethodNotAllowed) return @@ -204,7 +222,21 @@ func TestLogin(t *testing.T) { // nolint:gocyclo w.Header().Set("content-type", "application/json") require.NoError(t, json.NewEncoder(w).Encode(&response)) - }) + } + + // Start a test server that returns a real discovery document and answers refresh requests. + providerMux := http.NewServeMux() + successServer := httptest.NewServer(providerMux) + t.Cleanup(successServer.Close) + providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil)) + providerMux.HandleFunc("/token", tokenHandler) + + // Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post. + formPostProviderMux := http.NewServeMux() + formPostSuccessServer := httptest.NewServer(formPostProviderMux) + t.Cleanup(formPostSuccessServer.Close) + formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"})) + formPostProviderMux.HandleFunc("/token", tokenHandler) defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam // Call the handler function from the test server to calculate the response. @@ -218,8 +250,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generateState = func() (state.State, error) { return "test-state", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } - h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil } - h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil } + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } + h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ @@ -349,7 +381,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantToken: &testToken, }, { - name: "discovery failure", + name: "discovery failure due to 500 error", opt: func(t *testing.T) Option { return func(h *handlerState) error { return nil } }, @@ -357,6 +389,15 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""}, wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), }, + { + name: "discovery failure due to invalid response_modes_supported", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { return nil } + }, + issuer: brokenResponseModeServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""}, + wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL), + }, { name: "session cache hit with refreshable token", issuer: successServer.URL, @@ -451,38 +492,93 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) h.cache = cache - h.listenAddr = "invalid-listen-address" - + h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } + h.isTTY = func(int) bool { return false } return nil } }, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"", - "\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""}, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`, + `"level"=4 "msg"="Pinniped: Refreshing cached token."`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, // Expect this to fall through to the authorization code flow, so it fails here. - wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", + wantErr: "login failed: must have either a localhost listener or stdin must be a TTY", }, { - name: "listen failure", + name: "listen failure and non-tty stdin", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.listenAddr = "invalid-listen-address" + h.listen = func(net string, addr string) (net.Listener, error) { + assert.Equal(t, "tcp", net) + assert.Equal(t, "localhost:0", addr) + return nil, fmt.Errorf("some listen error") + } + h.isTTY = func(fd int) bool { + assert.Equal(t, fd, syscall.Stdin) + return false + } return nil } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", + issuer: successServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, + wantErr: "login failed: must have either a localhost listener or stdin must be a TTY", }, { - name: "browser open failure", + name: "listening disabled and manual prompt fails", opt: func(t *testing.T) Option { - return WithBrowserOpen(func(url string) error { - return fmt.Errorf("some browser open error") - }) + return func(h *handlerState) error { + require.NoError(t, WithSkipListen()(h)) + h.isTTY = func(fd int) bool { return true } + h.openURL = func(authorizeURL string) error { + parsed, err := url.Parse(authorizeURL) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri")) + require.Equal(t, "form_post", parsed.Query().Get("response_mode")) + return fmt.Errorf("some browser open error") + } + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "", fmt.Errorf("some prompt error") + } + return nil + } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "could not open browser: some browser open error", + issuer: formPostSuccessServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, + `"msg"="could not open browser" "error"="some browser open error"`, + }, + wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", + }, + { + name: "listen success and manual prompt succeeds", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } + h.isTTY = func(fd int) bool { return true } + h.openURL = func(authorizeURL string) error { + parsed, err := url.Parse(authorizeURL) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri")) + require.Equal(t, "form_post", parsed.Query().Get("response_mode")) + return nil + } + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "", fmt.Errorf("some prompt error") + } + return nil + } + }, + issuer: formPostSuccessServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, + wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", }, { name: "timeout waiting for callback", @@ -580,6 +676,68 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantToken: &testToken, }, + { + name: "callback returns success with request_mode=form_post", + clientID: "test-client-id", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } + + cache := &mockSessionCache{t: t, getReturnsToken: nil} + cacheKey := SessionCacheKey{ + Issuer: formPostSuccessServer.URL, + ClientID: "test-client-id", + Scopes: []string{"test-scope"}, + RedirectURI: "http://localhost:0/callback", + } + t.Cleanup(func() { + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) + require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens) + }) + require.NoError(t, WithSessionCache(cache)(h)) + require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h)) + + h.openURL = func(actualURL string) error { + parsedActualURL, err := url.Parse(actualURL) + require.NoError(t, err) + actualParams := parsedActualURL.Query() + + require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:") + actualParams.Del("redirect_uri") + + require.Equal(t, url.Values{ + // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: + // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 + // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g + "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge_method": []string{"S256"}, + "response_type": []string{"code"}, + "response_mode": []string{"form_post"}, + "scope": []string{"test-scope"}, + "nonce": []string{"test-nonce"}, + "state": []string{"test-state"}, + "access_type": []string{"offline"}, + "client_id": []string{"test-client-id"}, + }, actualParams) + + parsedActualURL.RawQuery = "" + require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String()) + + go func() { + h.callbacks <- callbackResult{token: &testToken} + }() + return nil + } + return nil + } + }, + issuer: formPostSuccessServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, + wantToken: &testToken, + }, { name: "upstream name and type are included in authorize request if upstream name is provided", clientID: "test-client-id", @@ -650,7 +808,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForValue = func(promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Username: ", promptLabel) return "", errors.New("some prompt error") } @@ -667,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") } + h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") } return nil } }, @@ -853,11 +1011,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generateState = func() (state.State, error) { return "test-state", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } - h.promptForValue = func(promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Username: ", promptLabel) return "some-upstream-username", nil } - h.promptForSecret = func(promptLabel string) (string, error) { + h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Password: ", promptLabel) return "some-upstream-password", nil } @@ -1287,10 +1445,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo WithContext(context.Background()), WithListenPort(0), WithScopes([]string{"test-scope"}), + WithSkipBrowserOpen(), tt.opt(t), WithLogger(testLogger), ) - require.Equal(t, tt.wantLogs, testLogger.Lines()) + testLogger.Expect(tt.wantLogs) if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) require.Nil(t, tok) @@ -1324,13 +1483,152 @@ func TestLogin(t *testing.T) { // nolint:gocyclo } } +func TestHandlePasteCallback(t *testing.T) { + const testRedirectURI = "http://127.0.0.1:12324/callback" + + tests := []struct { + name string + opt func(t *testing.T) Option + wantCallback *callbackResult + }{ + { + name: "no stdin available", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { + require.Equal(t, syscall.Stdin, fd) + return false + } + h.useFormPost = true + return nil + } + }, + }, + { + name: "no form_post mode available", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = false + return nil + } + }, + }, + { + name: "prompt fails", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel) + return "", fmt.Errorf("some prompt error") + } + return nil + } + }, + wantCallback: &callbackResult{ + err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"), + }, + }, + { + name: "redeeming code fails", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "invalid", nil + } + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(nil, fmt.Errorf("some exchange error")) + return mock + } + return nil + } + }, + wantCallback: &callbackResult{ + err: fmt.Errorf("some exchange error"), + }, + }, + { + name: "success", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "valid", nil + } + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) + return mock + } + return nil + } + }, + wantCallback: &callbackResult{ + token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + h := &handlerState{ + callbacks: make(chan callbackResult, 1), + state: state.State("test-state"), + pkce: pkce.Code("test-pkce"), + nonce: nonce.Nonce("test-nonce"), + } + if tt.opt != nil { + require.NoError(t, tt.opt(t)(h)) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var buf bytes.Buffer + h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf) + require.Equal(t, + "Log in by visiting this link:\n\n https://test-authorize-url/\n\n", + buf.String(), + ) + + if tt.wantCallback != nil { + select { + case <-time.After(1 * time.Second): + require.Fail(t, "timed out waiting to receive from callbacks channel") + case result := <-h.callbacks: + require.Equal(t, *tt.wantCallback, result) + } + } + }) + } +} + func TestHandleAuthCodeCallback(t *testing.T) { const testRedirectURI = "http://127.0.0.1:12324/callback" + withFormPostMode := func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + return nil + } + } tests := []struct { name string method string query string + body []byte + contentType string opt func(t *testing.T) Option wantErr string wantHTTPStatus int @@ -1342,6 +1640,24 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantErr: "wanted GET", wantHTTPStatus: http.StatusMethodNotAllowed, }, + { + name: "wrong method for form_post", + method: "GET", + query: "", + opt: withFormPostMode, + wantErr: "wanted POST", + wantHTTPStatus: http.StatusMethodNotAllowed, + }, + { + name: "invalid form for form_post", + method: "POST", + query: "", + contentType: "application/x-www-form-urlencoded", + body: []byte(`%`), + opt: withFormPostMode, + wantErr: `invalid form: invalid URL escape "%"`, + wantHTTPStatus: http.StatusBadRequest, + }, { name: "invalid state", query: "state=invalid", @@ -1396,6 +1712,26 @@ func TestHandleAuthCodeCallback(t *testing.T) { } }, }, + { + name: "valid form_post", + method: http.MethodPost, + contentType: "application/x-www-form-urlencoded", + body: []byte(`state=test-state&code=valid`), + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) + return mock + } + return nil + } + }, + }, } for _, tt := range tests { tt := tt @@ -1414,12 +1750,15 @@ func TestHandleAuthCodeCallback(t *testing.T) { defer cancel() resp := httptest.NewRecorder() - req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body)) require.NoError(t, err) req.URL.RawQuery = tt.query if tt.method != "" { req.Method = tt.method } + if tt.contentType != "" { + req.Header.Set("Content-Type", tt.contentType) + } err = h.handleAuthCodeCallback(resp, req) if tt.wantErr != "" { diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go index dbba13aa..2e69bc32 100644 --- a/test/integration/cli_test.go +++ b/test/integration/cli_test.go @@ -307,16 +307,15 @@ func runPinnipedLoginOIDC( reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr)) scanner := bufio.NewScanner(reader) - const prompt = "Please log in: " for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, prompt) { - loginURLChan <- strings.TrimPrefix(line, prompt) + loginURL, err := url.Parse(strings.TrimSpace(scanner.Text())) + if err == nil && loginURL.Scheme == "https" { + loginURLChan <- loginURL.String() return nil } } - return fmt.Errorf("expected stderr to contain %s", prompt) + return fmt.Errorf("expected stderr to contain login URL") }) // Start a background goroutine to read stdout from the CLI and parse out an ExecCredential. diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go index e4176d32..5102d59d 100644 --- a/test/integration/e2e_test.go +++ b/test/integration/e2e_test.go @@ -109,7 +109,7 @@ func TestE2EFullIntegration(t *testing.T) { }) // Add an OIDC upstream IDP and try using it to authenticate during kubectl commands. - t.Run("with Supervisor OIDC upstream IDP", func(t *testing.T) { + t.Run("with Supervisor OIDC upstream IDP and automatic flow", func(t *testing.T) { expectedUsername := env.SupervisorUpstreamOIDC.Username expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups @@ -195,16 +195,15 @@ func TestE2EFullIntegration(t *testing.T) { }() reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe)) - line, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("could not read login URL line from stderr: %w", err) + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + loginURL, err := url.Parse(strings.TrimSpace(scanner.Text())) + if err == nil && loginURL.Scheme == "https" { + loginURLChan <- loginURL.String() + return nil + } } - const prompt = "Please log in: " - if !strings.HasPrefix(line, prompt) { - return fmt.Errorf("expected %q to have prefix %q", line, prompt) - } - loginURLChan <- strings.TrimPrefix(line, prompt) - return readAndExpectEmpty(reader) + return fmt.Errorf("expected stderr to contain login URL") }) // Start a background goroutine to read stdout from kubectl and return the result as a string. @@ -242,17 +241,13 @@ func TestE2EFullIntegration(t *testing.T) { // Expect to be redirected to the upstream provider and log in. browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC) - // Expect to be redirected to the localhost callback. - t.Logf("waiting for redirect to callback") - browsertest.WaitForURL(t, page, regexp.MustCompile(`\Ahttp://127\.0\.0\.1:[0-9]+/callback\?.+\z`)) + // Expect to be redirected to the downstream callback which is serving the form_post HTML. + t.Logf("waiting for response page %s", downstream.Spec.Issuer) + browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer))) - // Wait for the "pre" element that gets rendered for a `text/plain` page, and - // assert that it contains the success message. - t.Logf("verifying success page") - browsertest.WaitForVisibleElements(t, page, "pre") - msg, err := page.First("pre").Text() - require.NoError(t, err) - require.Equal(t, "you have been logged in and may now close this tab", msg) + // The response page should have done the background fetch() and POST'ed to the CLI's callback. + // It should now be in the "success" state. + formpostExpectSuccessState(t, page) // Expect the CLI to output a list of namespaces in JSON format. t.Logf("waiting for kubectl to output namespace list JSON") @@ -275,6 +270,113 @@ func TestE2EFullIntegration(t *testing.T) { ) }) + t.Run("with Supervisor OIDC upstream IDP and manual flow", func(t *testing.T) { + expectedUsername := env.SupervisorUpstreamOIDC.Username + expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups + + // Create a ClusterRoleBinding to give our test user from the upstream read-only access to the cluster. + testlib.CreateTestClusterRoleBinding(t, + rbacv1.Subject{Kind: rbacv1.UserKind, APIGroup: rbacv1.GroupName, Name: expectedUsername}, + rbacv1.RoleRef{Kind: "ClusterRole", APIGroup: rbacv1.GroupName, Name: "view"}, + ) + testlib.WaitForUserToHaveAccess(t, expectedUsername, []string{}, &authorizationv1.ResourceAttributes{ + Verb: "get", + Group: "", + Version: "v1", + Resource: "namespaces", + }) + + // Create upstream OIDC provider and wait for it to become ready. + testlib.CreateTestOIDCIdentityProvider(t, idpv1alpha1.OIDCIdentityProviderSpec{ + Issuer: env.SupervisorUpstreamOIDC.Issuer, + TLS: &idpv1alpha1.TLSSpec{ + CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamOIDC.CABundle)), + }, + AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{ + AdditionalScopes: env.SupervisorUpstreamOIDC.AdditionalScopes, + }, + Claims: idpv1alpha1.OIDCClaims{ + Username: env.SupervisorUpstreamOIDC.UsernameClaim, + Groups: env.SupervisorUpstreamOIDC.GroupsClaim, + }, + Client: idpv1alpha1.OIDCClient{ + SecretName: testlib.CreateClientCredsSecret(t, env.SupervisorUpstreamOIDC.ClientID, env.SupervisorUpstreamOIDC.ClientSecret).Name, + }, + }, idpv1alpha1.PhaseReady) + + // Use a specific session cache for this test. + sessionCachePath := tempDir + "/oidc-test-sessions-manual.yaml" + kubeconfigPath := runPinnipedGetKubeconfig(t, env, pinnipedExe, tempDir, []string{ + "get", "kubeconfig", + "--concierge-api-group-suffix", env.APIGroupSuffix, + "--concierge-authenticator-type", "jwt", + "--concierge-authenticator-name", authenticator.Name, + "--oidc-skip-browser", + "--oidc-skip-listen", + "--oidc-ca-bundle", testCABundlePath, + "--oidc-session-cache", sessionCachePath, + }) + + // Run "kubectl get namespaces" which should trigger a browser login via the plugin. + start := time.Now() + kubectlCmd := exec.CommandContext(ctx, "kubectl", "get", "namespace", "--kubeconfig", kubeconfigPath) + kubectlCmd.Env = append(os.Environ(), env.ProxyEnv()...) + + ptyFile, err := pty.Start(kubectlCmd) + require.NoError(t, err) + + // Wait for the subprocess to print the login prompt. + t.Logf("waiting for CLI to output login URL and manual prompt") + output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ") + require.Contains(t, output, "Log in by visiting this link:") + require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ") + + // Find the line with the login URL. + var loginURL string + for _, line := range strings.Split(output, "\n") { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "https://") { + loginURL = trimmed + } + } + require.NotEmptyf(t, loginURL, "didn't find login URL in output: %s", output) + + t.Logf("navigating to login page") + require.NoError(t, page.Navigate(loginURL)) + + // Expect to be redirected to the upstream provider and log in. + browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC) + + // Expect to be redirected to the downstream callback which is serving the form_post HTML. + t.Logf("waiting for response page %s", downstream.Spec.Issuer) + browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer))) + + // The response page should have failed to automatically post, and should now be showing the manual instructions. + authCode := formpostExpectManualState(t, page) + + // Enter the auth code in the waiting prompt, followed by a newline. + t.Logf("'manually' pasting authorization code %q to waiting prompt", authCode) + _, err = ptyFile.WriteString(authCode + "\n") + require.NoError(t, err) + + // Read all of the remaining output from the subprocess until EOF. + t.Logf("waiting for kubectl to output namespace list") + remainingOutput, _ := ioutil.ReadAll(ptyFile) + // Ignore any errors returned because there is always an error on linux. + require.Greaterf(t, len(remainingOutput), 0, "expected to get some more output from the kubectl subcommand, but did not") + require.Greaterf(t, len(strings.Split(string(remainingOutput), "\n")), 2, "expected some namespaces to be returned, got %q", string(remainingOutput)) + t.Logf("first kubectl command took %s", time.Since(start).String()) + + requireUserCanUseKubectlWithoutAuthenticatingAgain(ctx, t, env, + downstream, + kubeconfigPath, + sessionCachePath, + pinnipedExe, + expectedUsername, + expectedGroups, + ) + }) + // Add an LDAP upstream IDP and try using it to authenticate during kubectl commands. t.Run("with Supervisor LDAP upstream IDP", func(t *testing.T) { if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) { @@ -376,7 +478,7 @@ func TestE2EFullIntegration(t *testing.T) { }) } -func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) { +func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) string { readFromFile := "" testlib.RequireEventuallyWithoutError(t, func() (bool, error) { @@ -390,6 +492,7 @@ func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) { } return false, nil // keep waiting and reading }, 1*time.Minute, 1*time.Second) + return readFromFile } func readAvailableOutput(t *testing.T, r io.Reader) (string, bool) { diff --git a/test/integration/formposthtml_test.go b/test/integration/formposthtml_test.go new file mode 100644 index 00000000..f44e1ae5 --- /dev/null +++ b/test/integration/formposthtml_test.go @@ -0,0 +1,257 @@ +// Copyright 2021 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package integration + +import ( + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + + "github.com/ory/fosite" + "github.com/ory/fosite/token/hmac" + "github.com/sclevine/agouti" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc/provider/formposthtml" + "go.pinniped.dev/test/testlib" + "go.pinniped.dev/test/testlib/browsertest" +) + +func TestFormPostHTML(t *testing.T) { + // Run a mock callback handler, simulating the one running in the CLI. + callbackURL, expectCallback := formpostCallbackServer(t) + + // Open a single browser for all subtests to use (in sequence). + page := browsertest.Open(t) + + t.Run("success", func(t *testing.T) { + // Serve the form_post template with successful parameters. + responseParams := formpostRandomParams(t) + formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams)) + + // Now we handle the callback and assert that we got what we expected. This should transition + // the UI into the success state. + expectCallback(t, responseParams) + formpostExpectSuccessState(t, page) + }) + + t.Run("callback server error", func(t *testing.T) { + // Serve the form_post template with a redirect URI that will return an HTTP 500 response. + responseParams := formpostRandomParams(t) + formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=500", responseParams)) + + // Now we handle the callback and assert that we got what we expected. + expectCallback(t, responseParams) + + // This is not 100% the behavior we'd like, but because our JS is making + // a cross-origin fetch() without CORS, we don't get to know anything + // about the response (even whether it is 200 vs. 500), so this case + // is the same as the success case. + // + // This case is fairly unlikely in practice, and if the CLI encounters + // an error it can also expose it via stderr anyway. + formpostExpectSuccessState(t, page) + }) + + t.Run("network failure", func(t *testing.T) { + // Serve the form_post template with a redirect URI that will return a network error. + responseParams := formpostRandomParams(t) + formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=close", responseParams)) + + // Now we handle the callback and assert that we got what we expected. + // This will trigger the callback server to close the client connection abruptly because + // of the `?fail=close` parameter above. + expectCallback(t, responseParams) + + // This failure should cause the UI to enter the "manual" state. + actualCode := formpostExpectManualState(t, page) + require.Equal(t, responseParams.Get("code"), actualCode) + }) + + t.Run("timeout", func(t *testing.T) { + // Serve the form_post template with successful parameters. + responseParams := formpostRandomParams(t) + formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams)) + + // Sleep for longer than the two second timeout. + // During this sleep we are blocking the callback from returning. + time.Sleep(3 * time.Second) + + // Assert that the timeout fires and we see the manual instructions. + actualCode := formpostExpectManualState(t, page) + require.Equal(t, responseParams.Get("code"), actualCode) + + // Now simulate the callback finally succeeding, in which case + // the manual instructions should disappear and we should see the success + // div instead. + expectCallback(t, responseParams) + formpostExpectSuccessState(t, page) + }) +} + +// formpostCallbackServer runs a test server that simulates the CLI's callback handler. +// It returns the URL of the running test server and a function for fetching the next +// received form POST parameters. +// +// The test server supports special `?fail=close` and `?fail=500` to force error cases. +func formpostCallbackServer(t *testing.T) (string, func(*testing.T, url.Values)) { + results := make(chan url.Values) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.NoError(t, r.ParseForm()) + + // Extract only the POST parameters (r.Form also contains URL query parameters). + postParams := url.Values{} + for k := range r.Form { + if v := r.PostFormValue(k); v != "" { + postParams.Set(k, v) + } + } + + // Send the form parameters back on the results channel, giving up if the + // request context is cancelled (such as if the client disconnects). + select { + case results <- postParams: + case <-r.Context().Done(): + return + } + + switch r.URL.Query().Get("fail") { + case "close": // If "fail=close" is passed, close the connection immediately. + if conn, _, err := w.(http.Hijacker).Hijack(); err == nil { + _ = conn.Close() + } + return + case "500": // If "fail=500" is passed, return a 500 error. + w.WriteHeader(http.StatusInternalServerError) + return + } + })) + t.Cleanup(func() { + close(results) + server.Close() + }) + return server.URL, func(t *testing.T, expected url.Values) { + t.Logf("expecting to get a POST callback...") + select { + case actual := <-results: + require.Equal(t, expected, actual, "did not receive expected callback") + case <-time.After(3 * time.Second): + t.Errorf("failed to receive expected callback %v", expected) + t.FailNow() + } + } +} + +// formpostTemplateServer runs a test server that serves formposthtml.Template() rendered with test parameters. +func formpostTemplateServer(t *testing.T, redirectURI string, responseParams url.Values) string { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fosite.WriteAuthorizeFormPostResponse(redirectURI, responseParams, formposthtml.Template(), w) + }) + server := httptest.NewServer(securityheader.WrapWithCustomCSP( + handler, + formposthtml.ContentSecurityPolicy(), + )) + t.Cleanup(server.Close) + return server.URL +} + +// formpostRandomParams is a helper to generate random OAuth2 response parameters for testing. +func formpostRandomParams(t *testing.T) url.Values { + generator := &hmac.HMACStrategy{GlobalSecret: testlib.RandBytes(t, 32), TokenEntropy: 32} + authCode, _, err := generator.Generate() + require.NoError(t, err) + return url.Values{ + "code": []string{authCode}, + "scope": []string{"openid offline_access pinniped:request-audience"}, + "state": []string{testlib.RandHex(t, 16)}, + } +} + +// formpostExpectTitle asserts that the page has the expected title. +func formpostExpectTitle(t *testing.T, page *agouti.Page, expected string) { + actual, err := page.Title() + require.NoError(t, err) + require.Equal(t, expected, actual) +} + +// formpostExpectTitle asserts that the page has the expected SVG/emoji favicon. +func formpostExpectFavicon(t *testing.T, page *agouti.Page, expected string) { + iconURL, err := page.First("#favicon").Attribute("href") + require.NoError(t, err) + require.True(t, strings.HasPrefix(iconURL, "data:image/svg+xml,