From 1e32530d7bc81744a55c84d3dda4e05266b13022 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Thu, 29 Jul 2021 17:49:16 -0500 Subject: [PATCH] Fix broken TTY after manual auth code prompt. This may be a temporary fix. It switches the manual auth code prompt to use `promptForValue()` instead of `promptForSecret()`. The `promptForSecret()` function no longer supports cancellation (the v0.9.2 behavior) and the method of cancelling in `promptForValue()` is now based on running the blocking read in a background goroutine, which is allowed to block forever or leak (which is not important for our CLI use case). This means that the authorization code is now visible in the user's terminal, but this is really not a big deal because of PKCE and the limited lifetime of an auth code. The main goroutine now correctly waits for the "manual prompt" goroutine to clean up, which now includes printing the extra newline that would normally have been entered by the user in the manual flow. The text of the manual login prompt is updated to be more concise and less scary (don't use the word "fail"). Signed-off-by: Matt Moyer --- pkg/oidcclient/login.go | 95 ++++++++++++++++++++---------------- pkg/oidcclient/login_test.go | 20 ++++---- test/integration/e2e_test.go | 4 +- 3 files changed, 66 insertions(+), 53 deletions(-) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index c649f4fe..a1b5e0b6 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -18,6 +18,7 @@ import ( "os" "sort" "strings" + "sync" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -112,7 +113,7 @@ type handlerState struct { getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) promptForValue func(ctx context.Context, promptLabel string) (string, error) - promptForSecret func(ctx context.Context, promptLabel string) (string, error) + promptForSecret func(promptLabel string) (string, error) callbacks chan callbackResult } @@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er callbackPath: "/callback", ctx: context.Background(), logger: logr.Discard(), // discard logs unless a logger is specified - callbacks: make(chan callbackResult), + callbacks: make(chan callbackResult, 2), httpClient: http.DefaultClient, // Default implementations of external dependencies (to be mocked in tests). @@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) { password := h.getEnv(defaultPasswordEnvVarName) if password == "" { - password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt) + password, err = h.promptForSecret(defaultLDAPPasswordPrompt) if err != nil { return "", "", fmt.Errorf("error prompting for password: %w", err) } @@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp h.logger.V(debugLogLevel).Error(err, "could not open browser") } - ctx, cancel := context.WithCancel(h.ctx) - defer cancel() - // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). - h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + ctx, cancel := context.WithCancel(h.ctx) + cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + defer func() { + cancel() + cleanupPrompt() + }() // Wait for either the web callback, a pasted auth code, or a timeout. select { @@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } -func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) { +func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() { _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) // If stdin is not a TTY, print the URL but don't prompt for the manual paste, // since we have no way of reading it. if !h.isTTY(stdin()) { - return + return func() {} } // If the server didn't support response_mode=form_post, don't bother prompting for the manual // code because the user isn't going to have any easy way to manually copy it anyway. if !h.useFormPost { - return + return func() {} } // Launch the manual auth code prompt in a background goroutine, which will be cancelled // if the parent context is cancelled (when the login succeeds or times out). + var wg sync.WaitGroup + wg.Add(1) go func() { - code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ") + defer func() { + // Always emit a newline so the kubectl output is visually separated from the login prompts. + _, _ = fmt.Fprintln(os.Stderr) + + wg.Done() + }() + code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ") if err != nil { + // Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing + // newline that simulates the user having pressed "enter". + _, _ = fmt.Fprint(os.Stderr, "[...]\n") + h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} return } @@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin token, err := h.redeemAuthCode(ctx, code) h.callbacks <- callbackResult{token: token, err: err} }() + return wg.Wait } func promptForValue(ctx context.Context, promptLabel string) (string, error) { @@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - // If the context is canceled, set the read deadline on stdin so the read immediately finishes. - ctx, cancel := context.WithCancel(ctx) - defer cancel() + type readResult struct { + text string + err error + } + readResults := make(chan readResult) go func() { - <-ctx.Done() - _ = os.Stdin.SetReadDeadline(time.Now()) + text, err := bufio.NewReader(os.Stdin).ReadString('\n') + readResults <- readResult{text, err} + close(readResults) }() - text, err := bufio.NewReader(os.Stdin).ReadString('\n') - if err != nil { - return "", fmt.Errorf("could read input from stdin: %w", err) + // If the context is canceled, return immediately. The ReadString() operation will stay hung in the background + // goroutine indefinitely. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + select { + case <-ctx.Done(): + return "", ctx.Err() + case r := <-readResults: + return strings.TrimSpace(r.text), r.err } - text = strings.TrimSpace(text) - return text, nil } -func promptForSecret(ctx context.Context, promptLabel string) (string, error) { +func promptForSecret(promptLabel string) (string, error) { if !term.IsTerminal(stdin()) { return "", errors.New("stdin is not connected to a terminal") } @@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - - // If the context is canceled, set the read deadline on stdin so the read immediately finishes. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - <-ctx.Done() - _ = os.Stdin.SetReadDeadline(time.Now()) - - // term.ReadPassword swallows the newline that was typed by the user, so to - // avoid the next line of output from happening on same line as the password - // prompt, we need to print a newline. - // - // Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next - // on stderr is formatted correctly. - _, _ = fmt.Fprint(os.Stderr, "\n") - }() - password, err := term.ReadPassword(stdin()) if err != nil { return "", fmt.Errorf("could not read password: %w", err) } + // term.ReadPassword swallows the newline that was typed by the user, so to + // avoid the next line of output from happening on same line as the password + // prompt, we need to print a newline. + _, err = fmt.Fprint(os.Stderr, "\n") + if err != nil { + return "", fmt.Errorf("could not print newline to stderr: %w", err) + } return string(password), err } @@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() { } go func() { _ = srv.Serve(listener) }() return func() { - // Gracefully shut down the server, allowing up to 5 00ms for + // Gracefully shut down the server, allowing up to 100ms for // clients to receive any in-flight responses. - shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond) + shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond) _ = srv.Shutdown(shutdownCtx) cancel() } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 4edd1e9a..94dc24c4 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } - h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil } + h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ @@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return fmt.Errorf("some browser open error") } - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return nil } - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") } + h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") } return nil } }, @@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "Username: ", promptLabel) return "some-upstream-username", nil } - h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string) (string, error) { require.Equal(t, "Password: ", promptLabel) return "some-upstream-password", nil } @@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } - h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } @@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { - assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel) + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel) return "", fmt.Errorf("some prompt error") } return nil @@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "invalid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} @@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "valid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go index 2086de3e..2a67726e 100644 --- a/test/integration/e2e_test.go +++ b/test/integration/e2e_test.go @@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) { // Wait for the subprocess to print the login prompt. t.Logf("waiting for CLI to output login URL and manual prompt") - output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ") + output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ") require.Contains(t, output, "Log in by visiting this link:") - require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ") + require.Contains(t, output, "Optionally, paste your authorization code: ") // Find the line with the login URL. var loginURL string