Fix broken TTY after manual auth code prompt.
This may be a temporary fix. It switches the manual auth code prompt to use `promptForValue()` instead of `promptForSecret()`. The `promptForSecret()` function no longer supports cancellation (the v0.9.2 behavior) and the method of cancelling in `promptForValue()` is now based on running the blocking read in a background goroutine, which is allowed to block forever or leak (which is not important for our CLI use case). This means that the authorization code is now visible in the user's terminal, but this is really not a big deal because of PKCE and the limited lifetime of an auth code. The main goroutine now correctly waits for the "manual prompt" goroutine to clean up, which now includes printing the extra newline that would normally have been entered by the user in the manual flow. The text of the manual login prompt is updated to be more concise and less scary (don't use the word "fail"). Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
0ab8e14e4a
commit
1e32530d7b
@ -18,6 +18,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
@ -112,7 +113,7 @@ type handlerState struct {
|
|||||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||||
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
||||||
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
||||||
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
|
promptForSecret func(promptLabel string) (string, error)
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
callbackPath: "/callback",
|
callbackPath: "/callback",
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
logger: logr.Discard(), // discard logs unless a logger is specified
|
logger: logr.Discard(), // discard logs unless a logger is specified
|
||||||
callbacks: make(chan callbackResult),
|
callbacks: make(chan callbackResult, 2),
|
||||||
httpClient: http.DefaultClient,
|
httpClient: http.DefaultClient,
|
||||||
|
|
||||||
// Default implementations of external dependencies (to be mocked in tests).
|
// Default implementations of external dependencies (to be mocked in tests).
|
||||||
@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
|
|||||||
|
|
||||||
password := h.getEnv(defaultPasswordEnvVarName)
|
password := h.getEnv(defaultPasswordEnvVarName)
|
||||||
if password == "" {
|
if password == "" {
|
||||||
password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
|
password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("error prompting for password: %w", err)
|
return "", "", fmt.Errorf("error prompting for password: %w", err)
|
||||||
}
|
}
|
||||||
@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
|||||||
h.logger.V(debugLogLevel).Error(err, "could not open browser")
|
h.logger.V(debugLogLevel).Error(err, "could not open browser")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(h.ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
|
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
|
||||||
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
|
ctx, cancel := context.WithCancel(h.ctx)
|
||||||
|
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
cleanupPrompt()
|
||||||
|
}()
|
||||||
|
|
||||||
// Wait for either the web callback, a pasted auth code, or a timeout.
|
// Wait for either the web callback, a pasted auth code, or a timeout.
|
||||||
select {
|
select {
|
||||||
@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
|
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
|
||||||
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
|
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
|
||||||
|
|
||||||
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
|
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
|
||||||
// since we have no way of reading it.
|
// since we have no way of reading it.
|
||||||
if !h.isTTY(stdin()) {
|
if !h.isTTY(stdin()) {
|
||||||
return
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
|
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
|
||||||
// code because the user isn't going to have any easy way to manually copy it anyway.
|
// code because the user isn't going to have any easy way to manually copy it anyway.
|
||||||
if !h.useFormPost {
|
if !h.useFormPost {
|
||||||
return
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
|
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
|
||||||
// if the parent context is cancelled (when the login succeeds or times out).
|
// if the parent context is cancelled (when the login succeeds or times out).
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
|
defer func() {
|
||||||
|
// Always emit a newline so the kubectl output is visually separated from the login prompts.
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr)
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
|
||||||
|
// newline that simulates the user having pressed "enter".
|
||||||
|
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
|
||||||
|
|
||||||
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
|
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
|
|||||||
token, err := h.redeemAuthCode(ctx, code)
|
token, err := h.redeemAuthCode(ctx, code)
|
||||||
h.callbacks <- callbackResult{token: token, err: err}
|
h.callbacks <- callbackResult{token: token, err: err}
|
||||||
}()
|
}()
|
||||||
|
return wg.Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
||||||
@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
|||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
type readResult struct {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
text string
|
||||||
defer cancel()
|
err error
|
||||||
|
}
|
||||||
|
readResults := make(chan readResult)
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
readResults <- readResult{text, err}
|
||||||
|
close(readResults)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
// If the context is canceled, return immediately. The ReadString() operation will stay hung in the background
|
||||||
if err != nil {
|
// goroutine indefinitely.
|
||||||
return "", fmt.Errorf("could read input from stdin: %w", err)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case r := <-readResults:
|
||||||
|
return strings.TrimSpace(r.text), r.err
|
||||||
}
|
}
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
return text, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
func promptForSecret(promptLabel string) (string, error) {
|
||||||
if !term.IsTerminal(stdin()) {
|
if !term.IsTerminal(stdin()) {
|
||||||
return "", errors.New("stdin is not connected to a terminal")
|
return "", errors.New("stdin is not connected to a terminal")
|
||||||
}
|
}
|
||||||
@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
|
||||||
|
|
||||||
// term.ReadPassword swallows the newline that was typed by the user, so to
|
|
||||||
// avoid the next line of output from happening on same line as the password
|
|
||||||
// prompt, we need to print a newline.
|
|
||||||
//
|
|
||||||
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
|
|
||||||
// on stderr is formatted correctly.
|
|
||||||
_, _ = fmt.Fprint(os.Stderr, "\n")
|
|
||||||
}()
|
|
||||||
|
|
||||||
password, err := term.ReadPassword(stdin())
|
password, err := term.ReadPassword(stdin())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not read password: %w", err)
|
return "", fmt.Errorf("could not read password: %w", err)
|
||||||
}
|
}
|
||||||
|
// term.ReadPassword swallows the newline that was typed by the user, so to
|
||||||
|
// avoid the next line of output from happening on same line as the password
|
||||||
|
// prompt, we need to print a newline.
|
||||||
|
_, err = fmt.Fprint(os.Stderr, "\n")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not print newline to stderr: %w", err)
|
||||||
|
}
|
||||||
return string(password), err
|
return string(password), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() {
|
|||||||
}
|
}
|
||||||
go func() { _ = srv.Serve(listener) }()
|
go func() { _ = srv.Serve(listener) }()
|
||||||
return func() {
|
return func() {
|
||||||
// Gracefully shut down the server, allowing up to 5 00ms for
|
// Gracefully shut down the server, allowing up to 100ms for
|
||||||
// clients to receive any in-flight responses.
|
// clients to receive any in-flight responses.
|
||||||
shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond)
|
shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond)
|
||||||
_ = srv.Shutdown(shutdownCtx)
|
_ = srv.Shutdown(shutdownCtx)
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||||
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
||||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
|
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
|
||||||
|
|
||||||
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
||||||
cacheKey := SessionCacheKey{
|
cacheKey := SessionCacheKey{
|
||||||
@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
||||||
return fmt.Errorf("some browser open error")
|
return fmt.Errorf("some browser open error")
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
|
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "Username: ", promptLabel)
|
require.Equal(t, "Username: ", promptLabel)
|
||||||
return "some-upstream-username", nil
|
return "some-upstream-username", nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
h.promptForSecret = func(promptLabel string) (string, error) {
|
||||||
require.Equal(t, "Password: ", promptLabel)
|
require.Equal(t, "Password: ", promptLabel)
|
||||||
return "some-upstream-password", nil
|
return "some-upstream-password", nil
|
||||||
}
|
}
|
||||||
@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
h.promptForSecret = func(promptLabel string) (string, error) {
|
||||||
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
|
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "invalid", nil
|
return "invalid", nil
|
||||||
}
|
}
|
||||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||||
@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "valid", nil
|
return "valid", nil
|
||||||
}
|
}
|
||||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||||
|
@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) {
|
|||||||
|
|
||||||
// Wait for the subprocess to print the login prompt.
|
// Wait for the subprocess to print the login prompt.
|
||||||
t.Logf("waiting for CLI to output login URL and manual prompt")
|
t.Logf("waiting for CLI to output login URL and manual prompt")
|
||||||
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
|
output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ")
|
||||||
require.Contains(t, output, "Log in by visiting this link:")
|
require.Contains(t, output, "Log in by visiting this link:")
|
||||||
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
|
require.Contains(t, output, "Optionally, paste your authorization code: ")
|
||||||
|
|
||||||
// Find the line with the login URL.
|
// Find the line with the login URL.
|
||||||
var loginURL string
|
var loginURL string
|
||||||
|
Loading…
Reference in New Issue
Block a user