Merge pull request #758 from mattmoyer/use-plain-authcode-prompt

Fix broken TTY after manual auth code prompt.
This commit is contained in:
Matt Moyer 2021-07-30 13:50:27 -05:00 committed by GitHub
commit f4badb3961
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 53 deletions

View File

@ -18,6 +18,7 @@ import (
"os"
"sort"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
@ -112,7 +113,7 @@ type handlerState struct {
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error)
callbacks chan callbackResult
}
@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
callbackPath: "/callback",
ctx: context.Background(),
logger: logr.Discard(), // discard logs unless a logger is specified
callbacks: make(chan callbackResult),
callbacks: make(chan callbackResult, 2),
httpClient: http.DefaultClient,
// Default implementations of external dependencies (to be mocked in tests).
@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName)
if password == "" {
password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err)
}
@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
h.logger.V(debugLogLevel).Error(err, "could not open browser")
}
ctx, cancel := context.WithCancel(h.ctx)
defer cancel()
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
defer func() {
cancel()
cleanupPrompt()
}()
// Wait for either the web callback, a pasted auth code, or a timeout.
select {
@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it.
if !h.isTTY(stdin()) {
return
return func() {}
}
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
// code because the user isn't going to have any easy way to manually copy it anyway.
if !h.useFormPost {
return
return func() {}
}
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
// if the parent context is cancelled (when the login succeeds or times out).
var wg sync.WaitGroup
wg.Add(1)
go func() {
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr)
wg.Done()
}()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
}
@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err}
}()
return wg.Wait
}
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
type readResult struct {
text string
err error
}
readResults := make(chan readResult)
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
readResults <- readResult{text, err}
close(readResults)
}()
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", fmt.Errorf("could read input from stdin: %w", err)
// If the context is canceled, return immediately. The ReadString() operation will stay hung in the background
// goroutine indefinitely.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
select {
case <-ctx.Done():
return "", ctx.Err()
case r := <-readResults:
return strings.TrimSpace(r.text), r.err
}
text = strings.TrimSpace(text)
return text, nil
}
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
func promptForSecret(promptLabel string) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
//
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
// on stderr is formatted correctly.
_, _ = fmt.Fprint(os.Stderr, "\n")
}()
password, err := term.ReadPassword(stdin())
if err != nil {
return "", fmt.Errorf("could not read password: %w", err)
}
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
}
return string(password), err
}
@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() {
}
go func() { _ = srv.Serve(listener) }()
return func() {
// Gracefully shut down the server, allowing up to 5 00ms for
// Gracefully shut down the server, allowing up to 100ms for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond)
shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond)
_ = srv.Shutdown(shutdownCtx)
cancel()
}

View File

@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
return nil
@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}

View File

@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) {
// Wait for the subprocess to print the login prompt.
t.Logf("waiting for CLI to output login URL and manual prompt")
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ")
require.Contains(t, output, "Log in by visiting this link:")
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
require.Contains(t, output, "Optionally, paste your authorization code: ")
// Find the line with the login URL.
var loginURL string