Fix broken TTY after manual auth code prompt.

This may be a temporary fix. It switches the manual auth code prompt to use `promptForValue()` instead of `promptForSecret()`. The `promptForSecret()` function no longer supports cancellation (the v0.9.2 behavior) and the method of cancelling in `promptForValue()` is now based on running the blocking read in a background goroutine, which is allowed to block forever or leak (which is not important for our CLI use case).

This means that the authorization code is now visible in the user's terminal, but this is really not a big deal because of PKCE and the limited lifetime of an auth code.

The main goroutine now correctly waits for the "manual prompt" goroutine to clean up, which now includes printing the extra newline that would normally have been entered by the user in the manual flow.

The text of the manual login prompt is updated to be more concise and less scary (don't use the word "fail").

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2021-07-29 17:49:16 -05:00
parent 0ab8e14e4a
commit 1e32530d7b
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
3 changed files with 66 additions and 53 deletions

View File

@ -18,6 +18,7 @@ import (
"os"
"sort"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
@ -112,7 +113,7 @@ type handlerState struct {
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error)
callbacks chan callbackResult
}
@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
callbackPath: "/callback",
ctx: context.Background(),
logger: logr.Discard(), // discard logs unless a logger is specified
callbacks: make(chan callbackResult),
callbacks: make(chan callbackResult, 2),
httpClient: http.DefaultClient,
// Default implementations of external dependencies (to be mocked in tests).
@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName)
if password == "" {
password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err)
}
@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
h.logger.V(debugLogLevel).Error(err, "could not open browser")
}
ctx, cancel := context.WithCancel(h.ctx)
defer cancel()
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
defer func() {
cancel()
cleanupPrompt()
}()
// Wait for either the web callback, a pasted auth code, or a timeout.
select {
@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it.
if !h.isTTY(stdin()) {
return
return func() {}
}
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
// code because the user isn't going to have any easy way to manually copy it anyway.
if !h.useFormPost {
return
return func() {}
}
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
// if the parent context is cancelled (when the login succeeds or times out).
var wg sync.WaitGroup
wg.Add(1)
go func() {
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr)
wg.Done()
}()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
}
@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err}
}()
return wg.Wait
}
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
type readResult struct {
text string
err error
}
readResults := make(chan readResult)
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
readResults <- readResult{text, err}
close(readResults)
}()
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", fmt.Errorf("could read input from stdin: %w", err)
// If the context is canceled, return immediately. The ReadString() operation will stay hung in the background
// goroutine indefinitely.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
select {
case <-ctx.Done():
return "", ctx.Err()
case r := <-readResults:
return strings.TrimSpace(r.text), r.err
}
text = strings.TrimSpace(text)
return text, nil
}
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
func promptForSecret(promptLabel string) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
//
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
// on stderr is formatted correctly.
_, _ = fmt.Fprint(os.Stderr, "\n")
}()
password, err := term.ReadPassword(stdin())
if err != nil {
return "", fmt.Errorf("could not read password: %w", err)
}
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
}
return string(password), err
}
@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() {
}
go func() { _ = srv.Serve(listener) }()
return func() {
// Gracefully shut down the server, allowing up to 5 00ms for
// Gracefully shut down the server, allowing up to 100ms for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond)
shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond)
_ = srv.Shutdown(shutdownCtx)
cancel()
}

View File

@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
return nil
@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}

View File

@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) {
// Wait for the subprocess to print the login prompt.
t.Logf("waiting for CLI to output login URL and manual prompt")
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ")
require.Contains(t, output, "Log in by visiting this link:")
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
require.Contains(t, output, "Optionally, paste your authorization code: ")
// Find the line with the login URL.
var loginURL string