s.
+ Array.from(document.querySelectorAll('.state')).forEach(e => e.hidden = true);
+
+ // Unhide the current state
.
+ const currentDiv = document.getElementById(id)
+ currentDiv.hidden = false;
+
+ // Set the window title.
+ document.title = currentDiv.dataset.title;
+
+ // Set the favicon using inline SVG (does not work on Safari).
+ document.getElementById('favicon').setAttribute(
+ 'href',
+ 'data:image/svg+xml,
'
+ );
+ }
+
+ // At load, show the spinner, hide the other divs, set the favicon, and
+ // replace the URL path with './' so the upstream auth code disappears.
+ transitionToState('loading');
+ window.history.replaceState(null, '', './');
+
+ // When the copy button is clicked, copy to the clipboard.
+ document.getElementById('manual-copy-button').onclick = () => {
+ const code = document.getElementById('manual-copy-button').innerText;
+ navigator.clipboard.writeText(code)
+ .then(() => console.info('copied authorization code ' + code + ' to clipboard'))
+ .catch(e => console.error('failed to copy code ' + code + ' to clipboard: ' + e));
+ };
+
+ // Set a timeout to transition to the "manual" state if nothing succeeds within 2s.
+ const timeout = setTimeout(() => transitionToState('manual'), 2000);
+
+ // Try to submit the POST callback, handling the success and error cases.
+ const responseParams = document.forms[0].elements;
+ fetch(
+ responseParams['redirect_uri'].value,
+ {
+ method: 'POST',
+ mode: 'no-cors',
+ headers: {'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8'},
+ body: responseParams['encoded_params'].value,
+ })
+ .then(() => clearTimeout(timeout))
+ .then(() => transitionToState('success'))
+ .catch(() => transitionToState('manual'));
+};
diff --git a/internal/oidc/provider/formposthtml/formposthtml.go b/internal/oidc/provider/formposthtml/formposthtml.go
new file mode 100644
index 00000000..4eeebf74
--- /dev/null
+++ b/internal/oidc/provider/formposthtml/formposthtml.go
@@ -0,0 +1,65 @@
+// Copyright 2021 the Pinniped contributors. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package formposthtml defines HTML templates used by the Supervisor.
+//nolint: gochecknoglobals // This package uses globals to ensure that all parsing and minifying happens at init.
+package formposthtml
+
+import (
+ "crypto/sha256"
+ _ "embed" // Needed to trigger //go:embed directives below.
+ "encoding/base64"
+ "html/template"
+ "strings"
+
+ "github.com/tdewolff/minify/v2/minify"
+)
+
+var (
+ //go:embed form_post.css
+ rawCSS string
+ minifiedCSS = mustMinify(minify.CSS(rawCSS))
+
+ //go:embed form_post.js
+ rawJS string
+ minifiedJS = mustMinify(minify.JS(rawJS))
+
+ //go:embed form_post.gohtml
+ rawHTMLTemplate string
+)
+
+// Parse the Go templated HTML and inject functions providing the minified inline CSS and JS.
+var parsedHTMLTemplate = template.Must(template.New("form_post.gohtml").Funcs(template.FuncMap{
+ "minifiedCSS": func() template.CSS { return template.CSS(minifiedCSS) },
+ "minifiedJS": func() template.JS { return template.JS(minifiedJS) }, //nolint:gosec // This is 100% static input, not attacker-controlled.
+}).Parse(rawHTMLTemplate))
+
+// Generate the CSP header value once since it's effectively constant:
+var cspValue = strings.Join([]string{
+ `default-src 'none'`,
+ `script-src '` + cspHash(minifiedJS) + `'`,
+ `style-src '` + cspHash(minifiedCSS) + `'`,
+ `img-src data:`,
+ `connect-src *`,
+ `frame-ancestors 'none'`,
+}, "; ")
+
+func mustMinify(s string, err error) string {
+ if err != nil {
+ panic(err)
+ }
+ return s
+}
+
+func cspHash(s string) string {
+ hashBytes := sha256.Sum256([]byte(s))
+ return "sha256-" + base64.StdEncoding.EncodeToString(hashBytes[:])
+}
+
+// ContentSecurityPolicy returns the Content-Security-Policy header value to make the Template() operate correctly.
+//
+// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/default-src#:~:text=%27%3Chash-algorithm%3E-%3Cbase64-value%3E%27.
+func ContentSecurityPolicy() string { return cspValue }
+
+// Template returns the html/template.Template for rendering the response_type=form_post response page.
+func Template() *template.Template { return parsedHTMLTemplate }
diff --git a/internal/oidc/provider/formposthtml/formposthtml_test.go b/internal/oidc/provider/formposthtml/formposthtml_test.go
new file mode 100644
index 00000000..a8a1a929
--- /dev/null
+++ b/internal/oidc/provider/formposthtml/formposthtml_test.go
@@ -0,0 +1,101 @@
+// Copyright 2021 the Pinniped contributors. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package formposthtml
+
+import (
+ "bytes"
+ "fmt"
+ "net/url"
+ "testing"
+
+ "github.com/ory/fosite"
+ "github.com/stretchr/testify/require"
+
+ "go.pinniped.dev/internal/here"
+)
+
+var (
+ testRedirectURL = "http://127.0.0.1:12345/callback"
+
+ testResponseParams = url.Values{
+ "code": []string{"test-S629KHsCCBYV0PQ6FDSrn6iEXtVImQRBh7NCAk.JezyUSdCiSslYjtUmv7V5VAgiCz3ZkES9mYldg9GhqU"},
+ "scope": []string{"openid offline_access pinniped:request-audience"},
+ "state": []string{"01234567890123456789012345678901"},
+ }
+
+ testExpectedFormPostOutput = here.Doc(`
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Login succeeded
+
You have successfully logged in. You may now close this tab.
+
+
+
Finish your login
+
To finish logging in, paste this authorization code into your command-line session:
+
+
+
+
+ `)
+
+ // It's okay if this changes in the future, but this gives us a chance to eyeball the formatting.
+ // Our browser-based integration tests should find any incompatibilities.
+ testExpectedCSP = `default-src 'none'; ` +
+ `script-src 'sha256-U+tKnJ2oMSYKSxmSX3V2mPBN8xdr9JpampKAhbSo108='; ` +
+ `style-src 'sha256-TLAQE3UR2KpwP7AzMCE4iPDizh7zLPx9UXeK5ntuoRg='; ` +
+ `img-src data:; ` +
+ `connect-src *; ` +
+ `frame-ancestors 'none'`
+)
+
+func TestTemplate(t *testing.T) {
+ // Use the Fosite helper to render the form, ensuring that the parameters all have the same names + types.
+ var buf bytes.Buffer
+ fosite.WriteAuthorizeFormPostResponse(testRedirectURL, testResponseParams, Template(), &buf)
+
+ // Render again so we can confirm that there is no error returned (Fosite ignores any error).
+ var buf2 bytes.Buffer
+ require.NoError(t, Template().Execute(&buf2, struct {
+ RedirURL string
+ Parameters url.Values
+ }{
+ RedirURL: testRedirectURL,
+ Parameters: testResponseParams,
+ }))
+
+ require.Equal(t, buf.String(), buf2.String())
+ require.Equal(t, testExpectedFormPostOutput, buf.String())
+}
+
+func TestContentSecurityPolicyHashes(t *testing.T) {
+ require.Equal(t, testExpectedCSP, ContentSecurityPolicy())
+}
+
+func TestHelpers(t *testing.T) {
+ // These are silly tests but it's easy to we might as well have them.
+ require.Equal(t, "test", mustMinify("test", nil))
+ require.PanicsWithError(t, "some error", func() { mustMinify("", fmt.Errorf("some error")) })
+
+ // Example test vector from https://content-security-policy.com/hash/.
+ require.Equal(t, "sha256-RFWPLDbv2BY+rCkDzsE+0fr8ylGr2R2faWMhq4lfEQc=", cspHash("doSomething();"))
+}
diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go
index 54fc8563..9286bff1 100644
--- a/internal/testutil/assertions.go
+++ b/internal/testutil/assertions.go
@@ -1,4 +1,4 @@
-// Copyright 2020 the Pinniped contributors. All Rights Reserved.
+// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
@@ -55,7 +55,9 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
}
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
- require.Equal(t, "default-src 'none'; frame-ancestors 'none'", response.Header().Get("Content-Security-Policy"))
+ // This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use.
+ require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'")
+
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))
require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options"))
diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go
index b8e7b0de..f690af52 100644
--- a/internal/testutil/oidctestutil/oidctestutil.go
+++ b/internal/testutil/oidctestutil/oidctestutil.go
@@ -235,10 +235,10 @@ func VerifyECDSAIDToken(
return token
}
-func RequireAuthcodeRedirectLocation(
+func RequireAuthCodeRegexpMatch(
t *testing.T,
- actualRedirectLocation string,
- wantRedirectLocationRegexp string,
+ actualContent string,
+ wantRegexp string,
kubeClient *fake.Clientset,
secretsClient v1.SecretInterface,
oauthStore fositestoragei.AllFositeStorage,
@@ -256,9 +256,9 @@ func RequireAuthcodeRedirectLocation(
t.Helper()
// Assert that Location header matches regular expression.
- regex := regexp.MustCompile(wantRedirectLocationRegexp)
- submatches := regex.FindStringSubmatch(actualRedirectLocation)
- require.Lenf(t, submatches, 2, "no regexp match in actualRedirectLocation: %q", actualRedirectLocation)
+ regex := regexp.MustCompile(wantRegexp)
+ submatches := regex.FindStringSubmatch(actualContent)
+ require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent)
capturedAuthCode := submatches[1]
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface
diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go
index fdf34eff..36b688e2 100644
--- a/pkg/oidcclient/login.go
+++ b/pkg/oidcclient/login.go
@@ -10,6 +10,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"mime"
"net"
"net/http"
@@ -17,6 +18,7 @@ import (
"os"
"sort"
"strings"
+ "syscall"
"time"
"github.com/coreos/go-oidc/v3/oidc"
@@ -87,6 +89,7 @@ type handlerState struct {
// Generated parameters of a login flow.
provider *oidc.Provider
oauth2Config *oauth2.Config
+ useFormPost bool
state state.State
nonce nonce.Nonce
pkce pkce.Code
@@ -96,10 +99,12 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error)
openURL func(string) error
+ listen func(string, string) (net.Listener, error)
+ isTTY func(int) bool
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
- promptForValue func(promptLabel string) (string, error)
- promptForSecret func(promptLabel string) (string, error)
+ promptForValue func(ctx context.Context, promptLabel string) (string, error)
+ promptForSecret func(ctx context.Context, promptLabel string) (string, error)
callbacks chan callbackResult
}
@@ -156,6 +161,9 @@ func WithScopes(scopes []string) Option {
// WithBrowserOpen overrides the default "open browser" functionality with a custom callback. If not specified,
// an implementation using https://github.com/pkg/browser will be used by default.
+//
+// Deprecated: this option will be removed in a future version of Pinniped. See the
+// WithSkipBrowserOpen() option instead.
func WithBrowserOpen(openURL func(url string) error) Option {
return func(h *handlerState) error {
h.openURL = openURL
@@ -163,6 +171,23 @@ func WithBrowserOpen(openURL func(url string) error) Option {
}
}
+// WithSkipBrowserOpen causes the login to only print the authorize URL, but skips attempting to
+// open the user's default web browser.
+func WithSkipBrowserOpen() Option {
+ return func(h *handlerState) error {
+ h.openURL = func(_ string) error { return nil }
+ return nil
+ }
+}
+
+// WithSkipListen causes the login skip starting the localhost listener, forcing the manual copy/paste login flow.
+func WithSkipListen() Option {
+ return func(h *handlerState) error {
+ h.listen = func(string, string) (net.Listener, error) { return nil, nil }
+ return nil
+ }
+}
+
// SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct {
Issuer string `json:"issuer"`
@@ -250,6 +275,8 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generateNonce: nonce.Generate,
generatePKCE: pkce.Generate,
openURL: browser.OpenURL,
+ listen: net.Listen,
+ isTTY: term.IsTerminal,
getProvider: upstreamoidc.New,
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
@@ -376,11 +403,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Ask the user for their username and password.
- username, err := h.promptForValue(defaultLDAPUsernamePrompt)
+ username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
if err != nil {
return nil, fmt.Errorf("error prompting for username: %w", err)
}
- password, err := h.promptForSecret(defaultLDAPPasswordPrompt)
+ password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
if err != nil {
return nil, fmt.Errorf("error prompting for password: %w", err)
}
@@ -475,30 +502,55 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
// Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint.
// Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error.
func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
- // Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number).
- listener, err := net.Listen("tcp", h.listenAddr)
+ // Attempt to open a local TCP listener, logging but otherwise ignoring any error.
+ listener, err := h.listen("tcp", h.listenAddr)
if err != nil {
- return nil, fmt.Errorf("could not open callback listener: %w", err)
+ h.logger.V(debugLogLevel).Error(err, "could not open callback listener")
+ }
+
+ // If the listener failed to start and stdin is not a TTY, then we have no hope of succeeding,
+ // since we won't be able to receive the web callback and we can't prompt for the manual auth code.
+ if listener == nil && !h.isTTY(syscall.Stdin) {
+ return nil, fmt.Errorf("login failed: must have either a localhost listener or stdin must be a TTY")
+ }
+
+ // Update the OAuth2 redirect_uri to match the actual listener address (if there is one), or just use
+ // a fake ":0" port if there is no listener running.
+ redirectURI := url.URL{Scheme: "http", Path: h.callbackPath}
+ if listener == nil {
+ redirectURI.Host = "127.0.0.1:0"
+ } else {
+ redirectURI.Host = listener.Addr().String()
+ }
+ h.oauth2Config.RedirectURL = redirectURI.String()
+
+ // If the server supports it, request response_mode=form_post.
+ authParams := *authorizeOptions
+ if h.useFormPost {
+ authParams = append(authParams, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
- h.oauth2Config.RedirectURL = (&url.URL{
- Scheme: "http",
- Host: listener.Addr().String(),
- Path: h.callbackPath,
- }).String()
// Now that we have a redirect URL with the listener port, we can build the authorize URL.
- authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), *authorizeOptions...)
+ authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...)
- // Start a callback server in a background goroutine.
- shutdown := h.serve(listener)
- defer shutdown()
-
- // Open the authorize URL in the users browser.
- if err := h.openURL(authorizeURL); err != nil {
- return nil, fmt.Errorf("could not open browser: %w", err)
+ // If there is a listener running, start serving the callback handler in a background goroutine.
+ if listener != nil {
+ shutdown := h.serve(listener)
+ defer shutdown()
}
- // Wait for either the callback or a timeout.
+ // Open the authorize URL in the users browser, logging but otherwise ignoring any error.
+ if err := h.openURL(authorizeURL); err != nil {
+ h.logger.V(debugLogLevel).Error(err, "could not open browser")
+ }
+
+ ctx, cancel := context.WithCancel(h.ctx)
+ defer cancel()
+
+ // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
+ h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
+
+ // Wait for either the web callback, a pasted auth code, or a timeout.
select {
case <-h.ctx.Done():
return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err())
@@ -510,7 +562,37 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}
-func promptForValue(promptLabel string) (string, error) {
+func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
+ _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
+
+ // If stdin is not a TTY, print the URL but don't prompt for the manual paste,
+ // since we have no way of reading it.
+ if !h.isTTY(syscall.Stdin) {
+ return
+ }
+
+ // If the server didn't support response_mode=form_post, don't bother prompting for the manual
+ // code because the user isn't going to have any easy way to manually copy it anyway.
+ if !h.useFormPost {
+ return
+ }
+
+ // Launch the manual auth code prompt in a background goroutine, which will be cancelled
+ // if the parent context is cancelled (when the login succeeds or times out).
+ go func() {
+ code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
+ if err != nil {
+ h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
+ return
+ }
+
+ // When a code is pasted, redeem it for a token and return that result on the callbacks channel.
+ token, err := h.redeemAuthCode(ctx, code)
+ h.callbacks <- callbackResult{token: token, err: err}
+ }()
+}
+
+func promptForValue(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal")
}
@@ -518,6 +600,15 @@ func promptForValue(promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
+
+ // If the context is canceled, set the read deadline on stdin so the read immediately finishes.
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ <-ctx.Done()
+ _ = os.Stdin.SetReadDeadline(time.Now())
+ }()
+
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", fmt.Errorf("could read input from stdin: %w", err)
@@ -526,7 +617,7 @@ func promptForValue(promptLabel string) (string, error) {
return text, nil
}
-func promptForSecret(promptLabel string) (string, error) {
+func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal")
}
@@ -534,17 +625,27 @@ func promptForSecret(promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
- password, err := term.ReadPassword(0)
+
+ // If the context is canceled, set the read deadline on stdin so the read immediately finishes.
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ <-ctx.Done()
+ _ = os.Stdin.SetReadDeadline(time.Now())
+
+ // term.ReadPassword swallows the newline that was typed by the user, so to
+ // avoid the next line of output from happening on same line as the password
+ // prompt, we need to print a newline.
+ //
+ // Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
+ // on stderr is formatted correctly.
+ _, _ = fmt.Fprint(os.Stderr, "\n")
+ }()
+
+ password, err := term.ReadPassword(syscall.Stdin)
if err != nil {
return "", fmt.Errorf("could not read password: %w", err)
}
- // term.ReadPassword swallows the newline that was typed by the user, so to
- // avoid the next line of output from happening on same line as the password
- // prompt, we need to print a newline.
- _, err = fmt.Fprint(os.Stderr, "\n")
- if err != nil {
- return "", fmt.Errorf("could not print newline to stderr: %w", err)
- }
return string(password), err
}
@@ -567,9 +668,27 @@ func (h *handlerState) initOIDCDiscovery() error {
Endpoint: h.provider.Endpoint(),
Scopes: h.scopes,
}
+
+ // Use response_mode=form_post if the provider supports it.
+ var discoveryClaims struct {
+ ResponseModesSupported []string `json:"response_modes_supported"`
+ }
+ if err := h.provider.Claims(&discoveryClaims); err != nil {
+ return fmt.Errorf("could not decode response_modes_supported in OIDC discovery from %q: %w", h.issuer, err)
+ }
+ h.useFormPost = stringSliceContains(discoveryClaims.ResponseModesSupported, "form_post")
return nil
}
+func stringSliceContains(slice []string, s string) bool {
+ for _, item := range slice {
+ if item == s {
+ return true
+ }
+ }
+ return false
+}
+
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience)
// Perform OIDC discovery. This may have already been performed if there was not a cached base token.
@@ -664,13 +783,29 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}
}()
- // Return HTTP 405 for anything that's not a GET.
- if r.Method != http.MethodGet {
- return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
+ var params url.Values
+ if h.useFormPost {
+ // Return HTTP 405 for anything that's not a POST.
+ if r.Method != http.MethodPost {
+ return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST")
+ }
+
+ // Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
+ if err := r.ParseForm(); err != nil {
+ return httperr.Wrap(http.StatusBadRequest, "invalid form", err)
+ }
+ params = r.Form
+ } else {
+ // Return HTTP 405 for anything that's not a GET.
+ if r.Method != http.MethodGet {
+ return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
+ }
+
+ // Pull response parameters from the URL query string.
+ params = r.URL.Query()
}
// Validate OAuth2 state and fail if it's incorrect (to block CSRF).
- params := r.URL.Query()
if err := h.state.Validate(params.Get("state")); err != nil {
return httperr.New(http.StatusForbidden, "missing or invalid state parameter")
}
@@ -685,14 +820,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// Exchange the authorization code for access, ID, and refresh tokens and perform required
// validations on the returned ID token.
- token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).
- ExchangeAuthcodeAndValidateTokens(
- r.Context(),
- params.Get("code"),
- h.pkce,
- h.nonce,
- h.oauth2Config.RedirectURL,
- )
+ token, err := h.redeemAuthCode(r.Context(), params.Get("code"))
if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
}
@@ -702,6 +830,17 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return nil
}
+func (h *handlerState) redeemAuthCode(ctx context.Context, code string) (*oidctypes.Token, error) {
+ return h.getProvider(h.oauth2Config, h.provider, h.httpClient).
+ ExchangeAuthcodeAndValidateTokens(
+ ctx,
+ code,
+ h.pkce,
+ h.nonce,
+ h.oauth2Config.RedirectURL,
+ )
+}
+
func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go
index d85c01ac..358becd3 100644
--- a/pkg/oidcclient/login_test.go
+++ b/pkg/oidcclient/login_test.go
@@ -4,15 +4,18 @@
package oidcclient
import (
+ "bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
+ "net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
+ "syscall"
"testing"
"time"
@@ -80,6 +83,22 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}))
t.Cleanup(errorServer.Close)
+ // Start a test server that returns discovery data with a broken response_modes_supported value.
+ brokenResponseModeMux := http.NewServeMux()
+ brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux)
+ brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("content-type", "application/json")
+ type providerJSON struct {
+ Issuer string `json:"issuer"`
+ ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string).
+ }
+ _ = json.NewEncoder(w).Encode(&providerJSON{
+ Issuer: brokenResponseModeServer.URL,
+ ResponseModesSupported: "invalid",
+ })
+ })
+ t.Cleanup(brokenResponseModeServer.Close)
+
// Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
@@ -100,30 +119,29 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
})
t.Cleanup(brokenTokenURLServer.Close)
- // Start a test server that returns a real discovery document and answers refresh requests.
- providerMux := http.NewServeMux()
- successServer := httptest.NewServer(providerMux)
- t.Cleanup(successServer.Close)
- providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
- return
+ discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
+ return
+ }
+ w.Header().Set("content-type", "application/json")
+ _ = json.NewEncoder(w).Encode(&struct {
+ Issuer string `json:"issuer"`
+ AuthURL string `json:"authorization_endpoint"`
+ TokenURL string `json:"token_endpoint"`
+ JWKSURL string `json:"jwks_uri"`
+ ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
+ }{
+ Issuer: server.URL,
+ AuthURL: server.URL + "/authorize",
+ TokenURL: server.URL + "/token",
+ JWKSURL: server.URL + "/keys",
+ ResponseModesSupported: responseModes,
+ })
}
- w.Header().Set("content-type", "application/json")
- type providerJSON struct {
- Issuer string `json:"issuer"`
- AuthURL string `json:"authorization_endpoint"`
- TokenURL string `json:"token_endpoint"`
- JWKSURL string `json:"jwks_uri"`
- }
- _ = json.NewEncoder(w).Encode(&providerJSON{
- Issuer: successServer.URL,
- AuthURL: successServer.URL + "/authorize",
- TokenURL: successServer.URL + "/token",
- JWKSURL: successServer.URL + "/keys",
- })
- })
- providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
+ }
+ tokenHandler := func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
@@ -204,7 +222,21 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
- })
+ }
+
+ // Start a test server that returns a real discovery document and answers refresh requests.
+ providerMux := http.NewServeMux()
+ successServer := httptest.NewServer(providerMux)
+ t.Cleanup(successServer.Close)
+ providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil))
+ providerMux.HandleFunc("/token", tokenHandler)
+
+ // Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post.
+ formPostProviderMux := http.NewServeMux()
+ formPostSuccessServer := httptest.NewServer(formPostProviderMux)
+ t.Cleanup(formPostSuccessServer.Close)
+ formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"}))
+ formPostProviderMux.HandleFunc("/token", tokenHandler)
defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam
// Call the handler function from the test server to calculate the response.
@@ -218,8 +250,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
- h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil }
- h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil }
+ h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
+ h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
@@ -349,7 +381,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantToken: &testToken,
},
{
- name: "discovery failure",
+ name: "discovery failure due to 500 error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
@@ -357,6 +389,15 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""},
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
+ {
+ name: "discovery failure due to invalid response_modes_supported",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error { return nil }
+ },
+ issuer: brokenResponseModeServer.URL,
+ wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""},
+ wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL),
+ },
{
name: "session cache hit with refreshable token",
issuer: successServer.URL,
@@ -451,38 +492,93 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
})
h.cache = cache
- h.listenAddr = "invalid-listen-address"
-
+ h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
+ h.isTTY = func(int) bool { return false }
return nil
}
},
- wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"",
- "\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""},
+ wantLogs: []string{
+ `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
+ `"level"=4 "msg"="Pinniped: Refreshing cached token."`,
+ `"msg"="could not open callback listener" "error"="some listen error"`,
+ },
// Expect this to fall through to the authorization code flow, so it fails here.
- wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
+ wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
},
{
- name: "listen failure",
+ name: "listen failure and non-tty stdin",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
- h.listenAddr = "invalid-listen-address"
+ h.listen = func(net string, addr string) (net.Listener, error) {
+ assert.Equal(t, "tcp", net)
+ assert.Equal(t, "localhost:0", addr)
+ return nil, fmt.Errorf("some listen error")
+ }
+ h.isTTY = func(fd int) bool {
+ assert.Equal(t, fd, syscall.Stdin)
+ return false
+ }
return nil
}
},
- issuer: successServer.URL,
- wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
- wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
+ issuer: successServer.URL,
+ wantLogs: []string{
+ `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
+ `"msg"="could not open callback listener" "error"="some listen error"`,
+ },
+ wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
},
{
- name: "browser open failure",
+ name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option {
- return WithBrowserOpen(func(url string) error {
- return fmt.Errorf("some browser open error")
- })
+ return func(h *handlerState) error {
+ require.NoError(t, WithSkipListen()(h))
+ h.isTTY = func(fd int) bool { return true }
+ h.openURL = func(authorizeURL string) error {
+ parsed, err := url.Parse(authorizeURL)
+ require.NoError(t, err)
+ require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
+ require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
+ return fmt.Errorf("some browser open error")
+ }
+ h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
+ return "", fmt.Errorf("some prompt error")
+ }
+ return nil
+ }
},
- issuer: successServer.URL,
- wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
- wantErr: "could not open browser: some browser open error",
+ issuer: formPostSuccessServer.URL,
+ wantLogs: []string{
+ `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
+ `"msg"="could not open browser" "error"="some browser open error"`,
+ },
+ wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
+ },
+ {
+ name: "listen success and manual prompt succeeds",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
+ h.isTTY = func(fd int) bool { return true }
+ h.openURL = func(authorizeURL string) error {
+ parsed, err := url.Parse(authorizeURL)
+ require.NoError(t, err)
+ require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
+ require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
+ return nil
+ }
+ h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
+ return "", fmt.Errorf("some prompt error")
+ }
+ return nil
+ }
+ },
+ issuer: formPostSuccessServer.URL,
+ wantLogs: []string{
+ `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
+ `"msg"="could not open callback listener" "error"="some listen error"`,
+ },
+ wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "timeout waiting for callback",
@@ -580,6 +676,68 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken,
},
+ {
+ name: "callback returns success with request_mode=form_post",
+ clientID: "test-client-id",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.generateState = func() (state.State, error) { return "test-state", nil }
+ h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
+ h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
+
+ cache := &mockSessionCache{t: t, getReturnsToken: nil}
+ cacheKey := SessionCacheKey{
+ Issuer: formPostSuccessServer.URL,
+ ClientID: "test-client-id",
+ Scopes: []string{"test-scope"},
+ RedirectURI: "http://localhost:0/callback",
+ }
+ t.Cleanup(func() {
+ require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
+ require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
+ require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
+ })
+ require.NoError(t, WithSessionCache(cache)(h))
+ require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
+
+ h.openURL = func(actualURL string) error {
+ parsedActualURL, err := url.Parse(actualURL)
+ require.NoError(t, err)
+ actualParams := parsedActualURL.Query()
+
+ require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
+ actualParams.Del("redirect_uri")
+
+ require.Equal(t, url.Values{
+ // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
+ // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
+ // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
+ "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
+ "code_challenge_method": []string{"S256"},
+ "response_type": []string{"code"},
+ "response_mode": []string{"form_post"},
+ "scope": []string{"test-scope"},
+ "nonce": []string{"test-nonce"},
+ "state": []string{"test-state"},
+ "access_type": []string{"offline"},
+ "client_id": []string{"test-client-id"},
+ }, actualParams)
+
+ parsedActualURL.RawQuery = ""
+ require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String())
+
+ go func() {
+ h.callbacks <- callbackResult{token: &testToken}
+ }()
+ return nil
+ }
+ return nil
+ }
+ },
+ issuer: formPostSuccessServer.URL,
+ wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
+ wantToken: &testToken,
+ },
{
name: "upstream name and type are included in authorize request if upstream name is provided",
clientID: "test-client-id",
@@ -650,7 +808,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
- h.promptForValue = func(promptLabel string) (string, error) {
+ h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error")
}
@@ -667,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
- h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") }
+ h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
@@ -853,11 +1011,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
- h.promptForValue = func(promptLabel string) (string, error) {
+ h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
- h.promptForSecret = func(promptLabel string) (string, error) {
+ h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
@@ -1287,10 +1445,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
WithContext(context.Background()),
WithListenPort(0),
WithScopes([]string{"test-scope"}),
+ WithSkipBrowserOpen(),
tt.opt(t),
WithLogger(testLogger),
)
- require.Equal(t, tt.wantLogs, testLogger.Lines())
+ testLogger.Expect(tt.wantLogs)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
@@ -1324,13 +1483,152 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}
}
+func TestHandlePasteCallback(t *testing.T) {
+ const testRedirectURI = "http://127.0.0.1:12324/callback"
+
+ tests := []struct {
+ name string
+ opt func(t *testing.T) Option
+ wantCallback *callbackResult
+ }{
+ {
+ name: "no stdin available",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.isTTY = func(fd int) bool {
+ require.Equal(t, syscall.Stdin, fd)
+ return false
+ }
+ h.useFormPost = true
+ return nil
+ }
+ },
+ },
+ {
+ name: "no form_post mode available",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.isTTY = func(fd int) bool { return true }
+ h.useFormPost = false
+ return nil
+ }
+ },
+ },
+ {
+ name: "prompt fails",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.isTTY = func(fd int) bool { return true }
+ h.useFormPost = true
+ h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
+ assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
+ return "", fmt.Errorf("some prompt error")
+ }
+ return nil
+ }
+ },
+ wantCallback: &callbackResult{
+ err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"),
+ },
+ },
+ {
+ name: "redeeming code fails",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.isTTY = func(fd int) bool { return true }
+ h.useFormPost = true
+ h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
+ return "invalid", nil
+ }
+ h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
+ h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
+ mock := mockUpstream(t)
+ mock.EXPECT().
+ ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
+ Return(nil, fmt.Errorf("some exchange error"))
+ return mock
+ }
+ return nil
+ }
+ },
+ wantCallback: &callbackResult{
+ err: fmt.Errorf("some exchange error"),
+ },
+ },
+ {
+ name: "success",
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.isTTY = func(fd int) bool { return true }
+ h.useFormPost = true
+ h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
+ return "valid", nil
+ }
+ h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
+ h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
+ mock := mockUpstream(t)
+ mock.EXPECT().
+ ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
+ Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
+ return mock
+ }
+ return nil
+ }
+ },
+ wantCallback: &callbackResult{
+ token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
+ },
+ },
+ }
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ h := &handlerState{
+ callbacks: make(chan callbackResult, 1),
+ state: state.State("test-state"),
+ pkce: pkce.Code("test-pkce"),
+ nonce: nonce.Nonce("test-nonce"),
+ }
+ if tt.opt != nil {
+ require.NoError(t, tt.opt(t)(h))
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+ defer cancel()
+
+ var buf bytes.Buffer
+ h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
+ require.Equal(t,
+ "Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
+ buf.String(),
+ )
+
+ if tt.wantCallback != nil {
+ select {
+ case <-time.After(1 * time.Second):
+ require.Fail(t, "timed out waiting to receive from callbacks channel")
+ case result := <-h.callbacks:
+ require.Equal(t, *tt.wantCallback, result)
+ }
+ }
+ })
+ }
+}
+
func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
+ withFormPostMode := func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.useFormPost = true
+ return nil
+ }
+ }
tests := []struct {
name string
method string
query string
+ body []byte
+ contentType string
opt func(t *testing.T) Option
wantErr string
wantHTTPStatus int
@@ -1342,6 +1640,24 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantErr: "wanted GET",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
+ {
+ name: "wrong method for form_post",
+ method: "GET",
+ query: "",
+ opt: withFormPostMode,
+ wantErr: "wanted POST",
+ wantHTTPStatus: http.StatusMethodNotAllowed,
+ },
+ {
+ name: "invalid form for form_post",
+ method: "POST",
+ query: "",
+ contentType: "application/x-www-form-urlencoded",
+ body: []byte(`%`),
+ opt: withFormPostMode,
+ wantErr: `invalid form: invalid URL escape "%"`,
+ wantHTTPStatus: http.StatusBadRequest,
+ },
{
name: "invalid state",
query: "state=invalid",
@@ -1396,6 +1712,26 @@ func TestHandleAuthCodeCallback(t *testing.T) {
}
},
},
+ {
+ name: "valid form_post",
+ method: http.MethodPost,
+ contentType: "application/x-www-form-urlencoded",
+ body: []byte(`state=test-state&code=valid`),
+ opt: func(t *testing.T) Option {
+ return func(h *handlerState) error {
+ h.useFormPost = true
+ h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
+ h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
+ mock := mockUpstream(t)
+ mock.EXPECT().
+ ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
+ Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
+ return mock
+ }
+ return nil
+ }
+ },
+ },
}
for _, tt := range tests {
tt := tt
@@ -1414,12 +1750,15 @@ func TestHandleAuthCodeCallback(t *testing.T) {
defer cancel()
resp := httptest.NewRecorder()
- req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body))
require.NoError(t, err)
req.URL.RawQuery = tt.query
if tt.method != "" {
req.Method = tt.method
}
+ if tt.contentType != "" {
+ req.Header.Set("Content-Type", tt.contentType)
+ }
err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" {
diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go
index dbba13aa..2e69bc32 100644
--- a/test/integration/cli_test.go
+++ b/test/integration/cli_test.go
@@ -307,16 +307,15 @@ func runPinnipedLoginOIDC(
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr))
scanner := bufio.NewScanner(reader)
- const prompt = "Please log in: "
for scanner.Scan() {
- line := scanner.Text()
- if strings.HasPrefix(line, prompt) {
- loginURLChan <- strings.TrimPrefix(line, prompt)
+ loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
+ if err == nil && loginURL.Scheme == "https" {
+ loginURLChan <- loginURL.String()
return nil
}
}
- return fmt.Errorf("expected stderr to contain %s", prompt)
+ return fmt.Errorf("expected stderr to contain login URL")
})
// Start a background goroutine to read stdout from the CLI and parse out an ExecCredential.
diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go
index e4176d32..5102d59d 100644
--- a/test/integration/e2e_test.go
+++ b/test/integration/e2e_test.go
@@ -109,7 +109,7 @@ func TestE2EFullIntegration(t *testing.T) {
})
// Add an OIDC upstream IDP and try using it to authenticate during kubectl commands.
- t.Run("with Supervisor OIDC upstream IDP", func(t *testing.T) {
+ t.Run("with Supervisor OIDC upstream IDP and automatic flow", func(t *testing.T) {
expectedUsername := env.SupervisorUpstreamOIDC.Username
expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups
@@ -195,16 +195,15 @@ func TestE2EFullIntegration(t *testing.T) {
}()
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe))
- line, err := reader.ReadString('\n')
- if err != nil {
- return fmt.Errorf("could not read login URL line from stderr: %w", err)
+ scanner := bufio.NewScanner(reader)
+ for scanner.Scan() {
+ loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
+ if err == nil && loginURL.Scheme == "https" {
+ loginURLChan <- loginURL.String()
+ return nil
+ }
}
- const prompt = "Please log in: "
- if !strings.HasPrefix(line, prompt) {
- return fmt.Errorf("expected %q to have prefix %q", line, prompt)
- }
- loginURLChan <- strings.TrimPrefix(line, prompt)
- return readAndExpectEmpty(reader)
+ return fmt.Errorf("expected stderr to contain login URL")
})
// Start a background goroutine to read stdout from kubectl and return the result as a string.
@@ -242,17 +241,13 @@ func TestE2EFullIntegration(t *testing.T) {
// Expect to be redirected to the upstream provider and log in.
browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC)
- // Expect to be redirected to the localhost callback.
- t.Logf("waiting for redirect to callback")
- browsertest.WaitForURL(t, page, regexp.MustCompile(`\Ahttp://127\.0\.0\.1:[0-9]+/callback\?.+\z`))
+ // Expect to be redirected to the downstream callback which is serving the form_post HTML.
+ t.Logf("waiting for response page %s", downstream.Spec.Issuer)
+ browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer)))
- // Wait for the "pre" element that gets rendered for a `text/plain` page, and
- // assert that it contains the success message.
- t.Logf("verifying success page")
- browsertest.WaitForVisibleElements(t, page, "pre")
- msg, err := page.First("pre").Text()
- require.NoError(t, err)
- require.Equal(t, "you have been logged in and may now close this tab", msg)
+ // The response page should have done the background fetch() and POST'ed to the CLI's callback.
+ // It should now be in the "success" state.
+ formpostExpectSuccessState(t, page)
// Expect the CLI to output a list of namespaces in JSON format.
t.Logf("waiting for kubectl to output namespace list JSON")
@@ -275,6 +270,113 @@ func TestE2EFullIntegration(t *testing.T) {
)
})
+ t.Run("with Supervisor OIDC upstream IDP and manual flow", func(t *testing.T) {
+ expectedUsername := env.SupervisorUpstreamOIDC.Username
+ expectedGroups := env.SupervisorUpstreamOIDC.ExpectedGroups
+
+ // Create a ClusterRoleBinding to give our test user from the upstream read-only access to the cluster.
+ testlib.CreateTestClusterRoleBinding(t,
+ rbacv1.Subject{Kind: rbacv1.UserKind, APIGroup: rbacv1.GroupName, Name: expectedUsername},
+ rbacv1.RoleRef{Kind: "ClusterRole", APIGroup: rbacv1.GroupName, Name: "view"},
+ )
+ testlib.WaitForUserToHaveAccess(t, expectedUsername, []string{}, &authorizationv1.ResourceAttributes{
+ Verb: "get",
+ Group: "",
+ Version: "v1",
+ Resource: "namespaces",
+ })
+
+ // Create upstream OIDC provider and wait for it to become ready.
+ testlib.CreateTestOIDCIdentityProvider(t, idpv1alpha1.OIDCIdentityProviderSpec{
+ Issuer: env.SupervisorUpstreamOIDC.Issuer,
+ TLS: &idpv1alpha1.TLSSpec{
+ CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamOIDC.CABundle)),
+ },
+ AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{
+ AdditionalScopes: env.SupervisorUpstreamOIDC.AdditionalScopes,
+ },
+ Claims: idpv1alpha1.OIDCClaims{
+ Username: env.SupervisorUpstreamOIDC.UsernameClaim,
+ Groups: env.SupervisorUpstreamOIDC.GroupsClaim,
+ },
+ Client: idpv1alpha1.OIDCClient{
+ SecretName: testlib.CreateClientCredsSecret(t, env.SupervisorUpstreamOIDC.ClientID, env.SupervisorUpstreamOIDC.ClientSecret).Name,
+ },
+ }, idpv1alpha1.PhaseReady)
+
+ // Use a specific session cache for this test.
+ sessionCachePath := tempDir + "/oidc-test-sessions-manual.yaml"
+ kubeconfigPath := runPinnipedGetKubeconfig(t, env, pinnipedExe, tempDir, []string{
+ "get", "kubeconfig",
+ "--concierge-api-group-suffix", env.APIGroupSuffix,
+ "--concierge-authenticator-type", "jwt",
+ "--concierge-authenticator-name", authenticator.Name,
+ "--oidc-skip-browser",
+ "--oidc-skip-listen",
+ "--oidc-ca-bundle", testCABundlePath,
+ "--oidc-session-cache", sessionCachePath,
+ })
+
+ // Run "kubectl get namespaces" which should trigger a browser login via the plugin.
+ start := time.Now()
+ kubectlCmd := exec.CommandContext(ctx, "kubectl", "get", "namespace", "--kubeconfig", kubeconfigPath)
+ kubectlCmd.Env = append(os.Environ(), env.ProxyEnv()...)
+
+ ptyFile, err := pty.Start(kubectlCmd)
+ require.NoError(t, err)
+
+ // Wait for the subprocess to print the login prompt.
+ t.Logf("waiting for CLI to output login URL and manual prompt")
+ output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
+ require.Contains(t, output, "Log in by visiting this link:")
+ require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
+
+ // Find the line with the login URL.
+ var loginURL string
+ for _, line := range strings.Split(output, "\n") {
+ trimmed := strings.TrimSpace(line)
+ if strings.HasPrefix(trimmed, "https://") {
+ loginURL = trimmed
+ }
+ }
+ require.NotEmptyf(t, loginURL, "didn't find login URL in output: %s", output)
+
+ t.Logf("navigating to login page")
+ require.NoError(t, page.Navigate(loginURL))
+
+ // Expect to be redirected to the upstream provider and log in.
+ browsertest.LoginToUpstream(t, page, env.SupervisorUpstreamOIDC)
+
+ // Expect to be redirected to the downstream callback which is serving the form_post HTML.
+ t.Logf("waiting for response page %s", downstream.Spec.Issuer)
+ browsertest.WaitForURL(t, page, regexp.MustCompile(regexp.QuoteMeta(downstream.Spec.Issuer)))
+
+ // The response page should have failed to automatically post, and should now be showing the manual instructions.
+ authCode := formpostExpectManualState(t, page)
+
+ // Enter the auth code in the waiting prompt, followed by a newline.
+ t.Logf("'manually' pasting authorization code %q to waiting prompt", authCode)
+ _, err = ptyFile.WriteString(authCode + "\n")
+ require.NoError(t, err)
+
+ // Read all of the remaining output from the subprocess until EOF.
+ t.Logf("waiting for kubectl to output namespace list")
+ remainingOutput, _ := ioutil.ReadAll(ptyFile)
+ // Ignore any errors returned because there is always an error on linux.
+ require.Greaterf(t, len(remainingOutput), 0, "expected to get some more output from the kubectl subcommand, but did not")
+ require.Greaterf(t, len(strings.Split(string(remainingOutput), "\n")), 2, "expected some namespaces to be returned, got %q", string(remainingOutput))
+ t.Logf("first kubectl command took %s", time.Since(start).String())
+
+ requireUserCanUseKubectlWithoutAuthenticatingAgain(ctx, t, env,
+ downstream,
+ kubeconfigPath,
+ sessionCachePath,
+ pinnipedExe,
+ expectedUsername,
+ expectedGroups,
+ )
+ })
+
// Add an LDAP upstream IDP and try using it to authenticate during kubectl commands.
t.Run("with Supervisor LDAP upstream IDP", func(t *testing.T) {
if len(env.ToolsNamespace) == 0 && !env.HasCapability(testlib.CanReachInternetLDAPPorts) {
@@ -376,7 +478,7 @@ func TestE2EFullIntegration(t *testing.T) {
})
}
-func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) {
+func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) string {
readFromFile := ""
testlib.RequireEventuallyWithoutError(t, func() (bool, error) {
@@ -390,6 +492,7 @@ func readFromFileUntilStringIsSeen(t *testing.T, f *os.File, until string) {
}
return false, nil // keep waiting and reading
}, 1*time.Minute, 1*time.Second)
+ return readFromFile
}
func readAvailableOutput(t *testing.T, r io.Reader) (string, bool) {
diff --git a/test/integration/formposthtml_test.go b/test/integration/formposthtml_test.go
new file mode 100644
index 00000000..f44e1ae5
--- /dev/null
+++ b/test/integration/formposthtml_test.go
@@ -0,0 +1,257 @@
+// Copyright 2021 the Pinniped contributors. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+package integration
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "regexp"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/ory/fosite"
+ "github.com/ory/fosite/token/hmac"
+ "github.com/sclevine/agouti"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "go.pinniped.dev/internal/httputil/securityheader"
+ "go.pinniped.dev/internal/oidc/provider/formposthtml"
+ "go.pinniped.dev/test/testlib"
+ "go.pinniped.dev/test/testlib/browsertest"
+)
+
+func TestFormPostHTML(t *testing.T) {
+ // Run a mock callback handler, simulating the one running in the CLI.
+ callbackURL, expectCallback := formpostCallbackServer(t)
+
+ // Open a single browser for all subtests to use (in sequence).
+ page := browsertest.Open(t)
+
+ t.Run("success", func(t *testing.T) {
+ // Serve the form_post template with successful parameters.
+ responseParams := formpostRandomParams(t)
+ formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams))
+
+ // Now we handle the callback and assert that we got what we expected. This should transition
+ // the UI into the success state.
+ expectCallback(t, responseParams)
+ formpostExpectSuccessState(t, page)
+ })
+
+ t.Run("callback server error", func(t *testing.T) {
+ // Serve the form_post template with a redirect URI that will return an HTTP 500 response.
+ responseParams := formpostRandomParams(t)
+ formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=500", responseParams))
+
+ // Now we handle the callback and assert that we got what we expected.
+ expectCallback(t, responseParams)
+
+ // This is not 100% the behavior we'd like, but because our JS is making
+ // a cross-origin fetch() without CORS, we don't get to know anything
+ // about the response (even whether it is 200 vs. 500), so this case
+ // is the same as the success case.
+ //
+ // This case is fairly unlikely in practice, and if the CLI encounters
+ // an error it can also expose it via stderr anyway.
+ formpostExpectSuccessState(t, page)
+ })
+
+ t.Run("network failure", func(t *testing.T) {
+ // Serve the form_post template with a redirect URI that will return a network error.
+ responseParams := formpostRandomParams(t)
+ formpostInitiate(t, page, formpostTemplateServer(t, callbackURL+"?fail=close", responseParams))
+
+ // Now we handle the callback and assert that we got what we expected.
+ // This will trigger the callback server to close the client connection abruptly because
+ // of the `?fail=close` parameter above.
+ expectCallback(t, responseParams)
+
+ // This failure should cause the UI to enter the "manual" state.
+ actualCode := formpostExpectManualState(t, page)
+ require.Equal(t, responseParams.Get("code"), actualCode)
+ })
+
+ t.Run("timeout", func(t *testing.T) {
+ // Serve the form_post template with successful parameters.
+ responseParams := formpostRandomParams(t)
+ formpostInitiate(t, page, formpostTemplateServer(t, callbackURL, responseParams))
+
+ // Sleep for longer than the two second timeout.
+ // During this sleep we are blocking the callback from returning.
+ time.Sleep(3 * time.Second)
+
+ // Assert that the timeout fires and we see the manual instructions.
+ actualCode := formpostExpectManualState(t, page)
+ require.Equal(t, responseParams.Get("code"), actualCode)
+
+ // Now simulate the callback finally succeeding, in which case
+ // the manual instructions should disappear and we should see the success
+ // div instead.
+ expectCallback(t, responseParams)
+ formpostExpectSuccessState(t, page)
+ })
+}
+
+// formpostCallbackServer runs a test server that simulates the CLI's callback handler.
+// It returns the URL of the running test server and a function for fetching the next
+// received form POST parameters.
+//
+// The test server supports special `?fail=close` and `?fail=500` to force error cases.
+func formpostCallbackServer(t *testing.T) (string, func(*testing.T, url.Values)) {
+ results := make(chan url.Values)
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.NoError(t, r.ParseForm())
+
+ // Extract only the POST parameters (r.Form also contains URL query parameters).
+ postParams := url.Values{}
+ for k := range r.Form {
+ if v := r.PostFormValue(k); v != "" {
+ postParams.Set(k, v)
+ }
+ }
+
+ // Send the form parameters back on the results channel, giving up if the
+ // request context is cancelled (such as if the client disconnects).
+ select {
+ case results <- postParams:
+ case <-r.Context().Done():
+ return
+ }
+
+ switch r.URL.Query().Get("fail") {
+ case "close": // If "fail=close" is passed, close the connection immediately.
+ if conn, _, err := w.(http.Hijacker).Hijack(); err == nil {
+ _ = conn.Close()
+ }
+ return
+ case "500": // If "fail=500" is passed, return a 500 error.
+ w.WriteHeader(http.StatusInternalServerError)
+ return
+ }
+ }))
+ t.Cleanup(func() {
+ close(results)
+ server.Close()
+ })
+ return server.URL, func(t *testing.T, expected url.Values) {
+ t.Logf("expecting to get a POST callback...")
+ select {
+ case actual := <-results:
+ require.Equal(t, expected, actual, "did not receive expected callback")
+ case <-time.After(3 * time.Second):
+ t.Errorf("failed to receive expected callback %v", expected)
+ t.FailNow()
+ }
+ }
+}
+
+// formpostTemplateServer runs a test server that serves formposthtml.Template() rendered with test parameters.
+func formpostTemplateServer(t *testing.T, redirectURI string, responseParams url.Values) string {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fosite.WriteAuthorizeFormPostResponse(redirectURI, responseParams, formposthtml.Template(), w)
+ })
+ server := httptest.NewServer(securityheader.WrapWithCustomCSP(
+ handler,
+ formposthtml.ContentSecurityPolicy(),
+ ))
+ t.Cleanup(server.Close)
+ return server.URL
+}
+
+// formpostRandomParams is a helper to generate random OAuth2 response parameters for testing.
+func formpostRandomParams(t *testing.T) url.Values {
+ generator := &hmac.HMACStrategy{GlobalSecret: testlib.RandBytes(t, 32), TokenEntropy: 32}
+ authCode, _, err := generator.Generate()
+ require.NoError(t, err)
+ return url.Values{
+ "code": []string{authCode},
+ "scope": []string{"openid offline_access pinniped:request-audience"},
+ "state": []string{testlib.RandHex(t, 16)},
+ }
+}
+
+// formpostExpectTitle asserts that the page has the expected title.
+func formpostExpectTitle(t *testing.T, page *agouti.Page, expected string) {
+ actual, err := page.Title()
+ require.NoError(t, err)
+ require.Equal(t, expected, actual)
+}
+
+// formpostExpectTitle asserts that the page has the expected SVG/emoji favicon.
+func formpostExpectFavicon(t *testing.T, page *agouti.Page, expected string) {
+ iconURL, err := page.First("#favicon").Attribute("href")
+ require.NoError(t, err)
+ require.True(t, strings.HasPrefix(iconURL, "data:image/svg+xml,