Merge pull request #758 from mattmoyer/use-plain-authcode-prompt

Fix broken TTY after manual auth code prompt.
This commit is contained in:
Matt Moyer 2021-07-30 13:50:27 -05:00 committed by GitHub
commit f4badb3961
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 53 deletions

View File

@ -18,6 +18,7 @@ import (
"os" "os"
"sort" "sort"
"strings" "strings"
"sync"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
@ -112,7 +113,7 @@ type handlerState struct {
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(ctx context.Context, promptLabel string) (string, error) promptForSecret func(promptLabel string) (string, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
callbackPath: "/callback", callbackPath: "/callback",
ctx: context.Background(), ctx: context.Background(),
logger: logr.Discard(), // discard logs unless a logger is specified logger: logr.Discard(), // discard logs unless a logger is specified
callbacks: make(chan callbackResult), callbacks: make(chan callbackResult, 2),
httpClient: http.DefaultClient, httpClient: http.DefaultClient,
// Default implementations of external dependencies (to be mocked in tests). // Default implementations of external dependencies (to be mocked in tests).
@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName) password := h.getEnv(defaultPasswordEnvVarName)
if password == "" { if password == "" {
password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt) password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err) return "", "", fmt.Errorf("error prompting for password: %w", err)
} }
@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
h.logger.V(debugLogLevel).Error(err, "could not open browser") h.logger.V(debugLogLevel).Error(err, "could not open browser")
} }
ctx, cancel := context.WithCancel(h.ctx)
defer cancel()
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
h.promptForWebLogin(ctx, authorizeURL, os.Stderr) ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
defer func() {
cancel()
cleanupPrompt()
}()
// Wait for either the web callback, a pasted auth code, or a timeout. // Wait for either the web callback, a pasted auth code, or a timeout.
select { select {
@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
} }
} }
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) { func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste, // If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it. // since we have no way of reading it.
if !h.isTTY(stdin()) { if !h.isTTY(stdin()) {
return return func() {}
} }
// If the server didn't support response_mode=form_post, don't bother prompting for the manual // If the server didn't support response_mode=form_post, don't bother prompting for the manual
// code because the user isn't going to have any easy way to manually copy it anyway. // code because the user isn't going to have any easy way to manually copy it anyway.
if !h.useFormPost { if !h.useFormPost {
return return func() {}
} }
// Launch the manual auth code prompt in a background goroutine, which will be cancelled // Launch the manual auth code prompt in a background goroutine, which will be cancelled
// if the parent context is cancelled (when the login succeeds or times out). // if the parent context is cancelled (when the login succeeds or times out).
var wg sync.WaitGroup
wg.Add(1)
go func() { go func() {
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ") defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr)
wg.Done()
}()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
if err != nil { if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return return
} }
@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
token, err := h.redeemAuthCode(ctx, code) token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err} h.callbacks <- callbackResult{token: token, err: err}
}() }()
return wg.Wait
} }
func promptForValue(ctx context.Context, promptLabel string) (string, error) { func promptForValue(ctx context.Context, promptLabel string) (string, error) {
@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
// If the context is canceled, set the read deadline on stdin so the read immediately finishes. type readResult struct {
ctx, cancel := context.WithCancel(ctx) text string
defer cancel() err error
}
readResults := make(chan readResult)
go func() { go func() {
<-ctx.Done() text, err := bufio.NewReader(os.Stdin).ReadString('\n')
_ = os.Stdin.SetReadDeadline(time.Now()) readResults <- readResult{text, err}
close(readResults)
}() }()
text, err := bufio.NewReader(os.Stdin).ReadString('\n') // If the context is canceled, return immediately. The ReadString() operation will stay hung in the background
if err != nil { // goroutine indefinitely.
return "", fmt.Errorf("could read input from stdin: %w", err) ctx, cancel := context.WithCancel(ctx)
defer cancel()
select {
case <-ctx.Done():
return "", ctx.Err()
case r := <-readResults:
return strings.TrimSpace(r.text), r.err
} }
text = strings.TrimSpace(text)
return text, nil
} }
func promptForSecret(ctx context.Context, promptLabel string) (string, error) { func promptForSecret(promptLabel string) (string, error) {
if !term.IsTerminal(stdin()) { if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
//
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
// on stderr is formatted correctly.
_, _ = fmt.Fprint(os.Stderr, "\n")
}()
password, err := term.ReadPassword(stdin()) password, err := term.ReadPassword(stdin())
if err != nil { if err != nil {
return "", fmt.Errorf("could not read password: %w", err) return "", fmt.Errorf("could not read password: %w", err)
} }
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
}
return string(password), err return string(password), err
} }
@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() {
} }
go func() { _ = srv.Serve(listener) }() go func() { _ = srv.Serve(listener) }()
return func() { return func() {
// Gracefully shut down the server, allowing up to 5 00ms for // Gracefully shut down the server, allowing up to 100ms for
// clients to receive any in-flight responses. // clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond) shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond)
_ = srv.Shutdown(shutdownCtx) _ = srv.Shutdown(shutdownCtx)
cancel() cancel()
} }

View File

@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil } h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil} cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error") return fmt.Errorf("some browser open error")
} }
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil return nil
} }
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") } h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
return nil return nil
} }
}, },
@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil return "some-upstream-username", nil
} }
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel) require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil return "some-upstream-password", nil
} }
@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel) assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "invalid", nil return "invalid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "valid", nil return "valid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}

View File

@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) {
// Wait for the subprocess to print the login prompt. // Wait for the subprocess to print the login prompt.
t.Logf("waiting for CLI to output login URL and manual prompt") t.Logf("waiting for CLI to output login URL and manual prompt")
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ") output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ")
require.Contains(t, output, "Log in by visiting this link:") require.Contains(t, output, "Log in by visiting this link:")
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ") require.Contains(t, output, "Optionally, paste your authorization code: ")
// Find the line with the login URL. // Find the line with the login URL.
var loginURL string var loginURL string