From 5029495fdbfc4a77883614a3e7c4ec6a536cf280 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 8 Jul 2021 14:32:44 -0500 Subject: [PATCH] Add manual paste flow to `pinniped login oidc` command. This adds a new login flow that allows manually pasting the authorization code instead of receiving a browser-based callback. Signed-off-by: Matt Moyer --- pkg/oidcclient/login.go | 106 ++++++++++++---- pkg/oidcclient/login_test.go | 226 ++++++++++++++++++++++++++++++++--- test/integration/cli_test.go | 9 +- test/integration/e2e_test.go | 17 ++- 4 files changed, 302 insertions(+), 56 deletions(-) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index baa9ac19..7bfec416 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "mime" "net" "net/http" @@ -98,6 +99,8 @@ type handlerState struct { generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error + listen func(string, string) (net.Listener, error) + isTTY func(int) bool getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) promptForValue func(ctx context.Context, promptLabel string) (string, error) @@ -264,6 +267,8 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, + listen: net.Listen, + isTTY: term.IsTerminal, getProvider: upstreamoidc.New, validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) { return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token) @@ -489,16 +494,27 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) ( // Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint. // Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error. func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { - // Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number). - listener, err := net.Listen("tcp", h.listenAddr) + // Attempt to open a local TCP listener, logging but otherwise ignoring any error. + listener, err := h.listen("tcp", h.listenAddr) if err != nil { - return nil, fmt.Errorf("could not open callback listener: %w", err) + h.logger.V(debugLogLevel).Error(err, "could not open callback listener") } - h.oauth2Config.RedirectURL = (&url.URL{ - Scheme: "http", - Host: listener.Addr().String(), - Path: h.callbackPath, - }).String() + + // If the listener failed to start and stdin is not a TTY, then we have no hope of succeeding, + // since we won't be able to receive the web callback and we can't prompt for the manual auth code. + if listener == nil && !h.isTTY(syscall.Stdin) { + return nil, fmt.Errorf("login failed: must have either a localhost listener or stdin must be a TTY") + } + + // Update the OAuth2 redirect_uri to match the actual listener address (if there is one), or just use + // a fake ":0" port if there is no listener running. + redirectURI := url.URL{Scheme: "http", Path: h.callbackPath} + if listener == nil { + redirectURI.Host = "127.0.0.1:0" + } else { + redirectURI.Host = listener.Addr().String() + } + h.oauth2Config.RedirectURL = redirectURI.String() // If the server supports it, request response_mode=form_post. authParams := *authorizeOptions @@ -509,16 +525,24 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp // Now that we have a redirect URL with the listener port, we can build the authorize URL. authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...) - // Start a callback server in a background goroutine. - shutdown := h.serve(listener) - defer shutdown() - - // Open the authorize URL in the users browser. - if err := h.openURL(authorizeURL); err != nil { - return nil, fmt.Errorf("could not open browser: %w", err) + // If there is a listener running, start serving the callback handler in a background goroutine. + if listener != nil { + shutdown := h.serve(listener) + defer shutdown() } - // Wait for either the callback or a timeout. + // Open the authorize URL in the users browser, logging but otherwise ignoring any error. + if err := h.openURL(authorizeURL); err != nil { + h.logger.V(debugLogLevel).Error(err, "could not open browser") + } + + ctx, cancel := context.WithCancel(h.ctx) + defer cancel() + + // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). + h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + + // Wait for either the web callback, a pasted auth code, or a timeout. select { case <-h.ctx.Done(): return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err()) @@ -530,6 +554,36 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } +func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) { + _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) + + // If stdin is not a TTY, print the URL but don't prompt for the manual paste, + // since we have no way of reading it. + if !h.isTTY(syscall.Stdin) { + return + } + + // If the server didn't support response_mode=form_post, don't bother prompting for the manual + // code because the user isn't going to have any easy way to manually copy it anyway. + if !h.useFormPost { + return + } + + // Launch the manual auth code prompt in a background goroutine, which will be cancelled + // if the parent context is cancelled (when the login succeeds or times out). + go func() { + code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ") + if err != nil { + h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} + return + } + + // When a code is pasted, redeem it for a token and return that result on the callbacks channel. + token, err := h.redeemAuthCode(ctx, code) + h.callbacks <- callbackResult{token: token, err: err} + }() +} + func promptForValue(ctx context.Context, promptLabel string) (string, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return "", errors.New("stdin is not connected to a terminal") @@ -758,14 +812,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). - ExchangeAuthcodeAndValidateTokens( - r.Context(), - params.Get("code"), - h.pkce, - h.nonce, - h.oauth2Config.RedirectURL, - ) + token, err := h.redeemAuthCode(r.Context(), params.Get("code")) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } @@ -775,6 +822,17 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return nil } +func (h *handlerState) redeemAuthCode(ctx context.Context, code string) (*oidctypes.Token, error) { + return h.getProvider(h.oauth2Config, h.provider, h.httpClient). + ExchangeAuthcodeAndValidateTokens( + ctx, + code, + h.pkce, + h.nonce, + h.oauth2Config.RedirectURL, + ) +} + func (h *handlerState) serve(listener net.Listener) func() { mux := http.NewServeMux() mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 907935ef..87587ecf 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -10,10 +10,12 @@ import ( "errors" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "net/url" "strings" + "syscall" "testing" "time" @@ -490,38 +492,94 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) h.cache = cache - h.listenAddr = "invalid-listen-address" - + h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } + h.isTTY = func(int) bool { return false } return nil } }, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"", - "\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""}, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`, + `"level"=4 "msg"="Pinniped: Refreshing cached token."`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, // Expect this to fall through to the authorization code flow, so it fails here. - wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", + wantErr: "login failed: must have either a localhost listener or stdin must be a TTY", }, { - name: "listen failure", + name: "listen failure and non-tty stdin", opt: func(t *testing.T) Option { return func(h *handlerState) error { - h.listenAddr = "invalid-listen-address" + h.listen = func(net string, addr string) (net.Listener, error) { + assert.Equal(t, "tcp", net) + assert.Equal(t, "localhost:0", addr) + return nil, fmt.Errorf("some listen error") + } + h.isTTY = func(fd int) bool { + assert.Equal(t, fd, syscall.Stdin) + return false + } return nil } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", + issuer: successServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, + wantErr: "login failed: must have either a localhost listener or stdin must be a TTY", }, { - name: "browser open failure", + name: "listen failure and manual prompt fails", opt: func(t *testing.T) Option { - return WithBrowserOpen(func(url string) error { - return fmt.Errorf("some browser open error") - }) + return func(h *handlerState) error { + h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } + h.isTTY = func(fd int) bool { return true } + h.openURL = func(authorizeURL string) error { + parsed, err := url.Parse(authorizeURL) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri")) + require.Equal(t, "form_post", parsed.Query().Get("response_mode")) + return fmt.Errorf("some browser open error") + } + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "", fmt.Errorf("some prompt error") + } + return nil + } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "could not open browser: some browser open error", + issuer: formPostSuccessServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, + `"msg"="could not open callback listener" "error"="some listen error"`, + `"msg"="could not open browser" "error"="some browser open error"`, + }, + wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", + }, + { + name: "listen success and manual prompt succeeds", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } + h.isTTY = func(fd int) bool { return true } + h.openURL = func(authorizeURL string) error { + parsed, err := url.Parse(authorizeURL) + require.NoError(t, err) + require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri")) + require.Equal(t, "form_post", parsed.Query().Get("response_mode")) + return nil + } + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "", fmt.Errorf("some prompt error") + } + return nil + } + }, + issuer: formPostSuccessServer.URL, + wantLogs: []string{ + `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, + `"msg"="could not open callback listener" "error"="some listen error"`, + }, + wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", }, { name: "timeout waiting for callback", @@ -1388,10 +1446,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo WithContext(context.Background()), WithListenPort(0), WithScopes([]string{"test-scope"}), + WithSkipBrowserOpen(), tt.opt(t), WithLogger(testLogger), ) - require.Equal(t, tt.wantLogs, testLogger.Lines()) + testLogger.Expect(tt.wantLogs) if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) require.Nil(t, tok) @@ -1425,6 +1484,137 @@ func TestLogin(t *testing.T) { // nolint:gocyclo } } +func TestHandlePasteCallback(t *testing.T) { + const testRedirectURI = "http://127.0.0.1:12324/callback" + + tests := []struct { + name string + opt func(t *testing.T) Option + wantCallback *callbackResult + }{ + { + name: "no stdin available", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { + require.Equal(t, syscall.Stdin, fd) + return false + } + h.useFormPost = true + return nil + } + }, + }, + { + name: "no form_post mode available", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = false + return nil + } + }, + }, + { + name: "prompt fails", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel) + return "", fmt.Errorf("some prompt error") + } + return nil + } + }, + wantCallback: &callbackResult{ + err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"), + }, + }, + { + name: "redeeming code fails", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "invalid", nil + } + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(nil, fmt.Errorf("some exchange error")) + return mock + } + return nil + } + }, + wantCallback: &callbackResult{ + err: fmt.Errorf("some exchange error"), + }, + }, + { + name: "success", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.isTTY = func(fd int) bool { return true } + h.useFormPost = true + h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + return "valid", nil + } + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) + return mock + } + return nil + } + }, + wantCallback: &callbackResult{ + token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + h := &handlerState{ + callbacks: make(chan callbackResult, 1), + state: state.State("test-state"), + pkce: pkce.Code("test-pkce"), + nonce: nonce.Nonce("test-nonce"), + } + if tt.opt != nil { + require.NoError(t, tt.opt(t)(h)) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var buf bytes.Buffer + h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf) + require.Equal(t, + "Log in by visiting this link:\n\n https://test-authorize-url/\n\n", + buf.String(), + ) + + if tt.wantCallback != nil { + select { + case <-time.After(1 * time.Second): + require.Fail(t, "timed out waiting to receive from callbacks channel") + case result := <-h.callbacks: + require.Equal(t, *tt.wantCallback, result) + } + } + }) + } +} + func TestHandleAuthCodeCallback(t *testing.T) { const testRedirectURI = "http://127.0.0.1:12324/callback" diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go index dbba13aa..2e69bc32 100644 --- a/test/integration/cli_test.go +++ b/test/integration/cli_test.go @@ -307,16 +307,15 @@ func runPinnipedLoginOIDC( reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr)) scanner := bufio.NewScanner(reader) - const prompt = "Please log in: " for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, prompt) { - loginURLChan <- strings.TrimPrefix(line, prompt) + loginURL, err := url.Parse(strings.TrimSpace(scanner.Text())) + if err == nil && loginURL.Scheme == "https" { + loginURLChan <- loginURL.String() return nil } } - return fmt.Errorf("expected stderr to contain %s", prompt) + return fmt.Errorf("expected stderr to contain login URL") }) // Start a background goroutine to read stdout from the CLI and parse out an ExecCredential. diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go index 57cc2a6f..768210a7 100644 --- a/test/integration/e2e_test.go +++ b/test/integration/e2e_test.go @@ -195,16 +195,15 @@ func TestE2EFullIntegration(t *testing.T) { }() reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe)) - line, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("could not read login URL line from stderr: %w", err) + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + loginURL, err := url.Parse(strings.TrimSpace(scanner.Text())) + if err == nil && loginURL.Scheme == "https" { + loginURLChan <- loginURL.String() + return nil + } } - const prompt = "Please log in: " - if !strings.HasPrefix(line, prompt) { - return fmt.Errorf("expected %q to have prefix %q", line, prompt) - } - loginURLChan <- strings.TrimPrefix(line, prompt) - return readAndExpectEmpty(reader) + return fmt.Errorf("expected stderr to contain login URL") }) // Start a background goroutine to read stdout from kubectl and return the result as a string.