diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index c649f4fe..a1b5e0b6 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -18,6 +18,7 @@ import ( "os" "sort" "strings" + "sync" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -112,7 +113,7 @@ type handlerState struct { getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) promptForValue func(ctx context.Context, promptLabel string) (string, error) - promptForSecret func(ctx context.Context, promptLabel string) (string, error) + promptForSecret func(promptLabel string) (string, error) callbacks chan callbackResult } @@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er callbackPath: "/callback", ctx: context.Background(), logger: logr.Discard(), // discard logs unless a logger is specified - callbacks: make(chan callbackResult), + callbacks: make(chan callbackResult, 2), httpClient: http.DefaultClient, // Default implementations of external dependencies (to be mocked in tests). @@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) { password := h.getEnv(defaultPasswordEnvVarName) if password == "" { - password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt) + password, err = h.promptForSecret(defaultLDAPPasswordPrompt) if err != nil { return "", "", fmt.Errorf("error prompting for password: %w", err) } @@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp h.logger.V(debugLogLevel).Error(err, "could not open browser") } - ctx, cancel := context.WithCancel(h.ctx) - defer cancel() - // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). - h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + ctx, cancel := context.WithCancel(h.ctx) + cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + defer func() { + cancel() + cleanupPrompt() + }() // Wait for either the web callback, a pasted auth code, or a timeout. select { @@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } -func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) { +func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() { _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) // If stdin is not a TTY, print the URL but don't prompt for the manual paste, // since we have no way of reading it. if !h.isTTY(stdin()) { - return + return func() {} } // If the server didn't support response_mode=form_post, don't bother prompting for the manual // code because the user isn't going to have any easy way to manually copy it anyway. if !h.useFormPost { - return + return func() {} } // Launch the manual auth code prompt in a background goroutine, which will be cancelled // if the parent context is cancelled (when the login succeeds or times out). + var wg sync.WaitGroup + wg.Add(1) go func() { - code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ") + defer func() { + // Always emit a newline so the kubectl output is visually separated from the login prompts. + _, _ = fmt.Fprintln(os.Stderr) + + wg.Done() + }() + code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ") if err != nil { + // Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing + // newline that simulates the user having pressed "enter". + _, _ = fmt.Fprint(os.Stderr, "[...]\n") + h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} return } @@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin token, err := h.redeemAuthCode(ctx, code) h.callbacks <- callbackResult{token: token, err: err} }() + return wg.Wait } func promptForValue(ctx context.Context, promptLabel string) (string, error) { @@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - // If the context is canceled, set the read deadline on stdin so the read immediately finishes. - ctx, cancel := context.WithCancel(ctx) - defer cancel() + type readResult struct { + text string + err error + } + readResults := make(chan readResult) go func() { - <-ctx.Done() - _ = os.Stdin.SetReadDeadline(time.Now()) + text, err := bufio.NewReader(os.Stdin).ReadString('\n') + readResults <- readResult{text, err} + close(readResults) }() - text, err := bufio.NewReader(os.Stdin).ReadString('\n') - if err != nil { - return "", fmt.Errorf("could read input from stdin: %w", err) + // If the context is canceled, return immediately. The ReadString() operation will stay hung in the background + // goroutine indefinitely. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + select { + case <-ctx.Done(): + return "", ctx.Err() + case r := <-readResults: + return strings.TrimSpace(r.text), r.err } - text = strings.TrimSpace(text) - return text, nil } -func promptForSecret(ctx context.Context, promptLabel string) (string, error) { +func promptForSecret(promptLabel string) (string, error) { if !term.IsTerminal(stdin()) { return "", errors.New("stdin is not connected to a terminal") } @@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - - // If the context is canceled, set the read deadline on stdin so the read immediately finishes. - ctx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - <-ctx.Done() - _ = os.Stdin.SetReadDeadline(time.Now()) - - // term.ReadPassword swallows the newline that was typed by the user, so to - // avoid the next line of output from happening on same line as the password - // prompt, we need to print a newline. - // - // Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next - // on stderr is formatted correctly. - _, _ = fmt.Fprint(os.Stderr, "\n") - }() - password, err := term.ReadPassword(stdin()) if err != nil { return "", fmt.Errorf("could not read password: %w", err) } + // term.ReadPassword swallows the newline that was typed by the user, so to + // avoid the next line of output from happening on same line as the password + // prompt, we need to print a newline. + _, err = fmt.Fprint(os.Stderr, "\n") + if err != nil { + return "", fmt.Errorf("could not print newline to stderr: %w", err) + } return string(password), err } @@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() { } go func() { _ = srv.Serve(listener) }() return func() { - // Gracefully shut down the server, allowing up to 5 00ms for + // Gracefully shut down the server, allowing up to 100ms for // clients to receive any in-flight responses. - shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond) + shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond) _ = srv.Shutdown(shutdownCtx) cancel() } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 4edd1e9a..94dc24c4 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } - h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil } + h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ @@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return fmt.Errorf("some browser open error") } - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return nil } - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") } + h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") } return nil } }, @@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.Equal(t, "Username: ", promptLabel) return "some-upstream-username", nil } - h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string) (string, error) { require.Equal(t, "Password: ", promptLabel) return "some-upstream-password", nil } @@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } - h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } @@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { - assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel) + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel) return "", fmt.Errorf("some prompt error") } return nil @@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "invalid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} @@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "valid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} diff --git a/test/integration/e2e_test.go b/test/integration/e2e_test.go index 2086de3e..2a67726e 100644 --- a/test/integration/e2e_test.go +++ b/test/integration/e2e_test.go @@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) { // Wait for the subprocess to print the login prompt. t.Logf("waiting for CLI to output login URL and manual prompt") - output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ") + output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ") require.Contains(t, output, "Log in by visiting this link:") - require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ") + require.Contains(t, output, "Optionally, paste your authorization code: ") // Find the line with the login URL. var loginURL string