Merge pull request #758 from mattmoyer/use-plain-authcode-prompt
Fix broken TTY after manual auth code prompt.
This commit is contained in:
commit
f4badb3961
@ -18,6 +18,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
@ -112,7 +113,7 @@ type handlerState struct {
|
|||||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||||
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
||||||
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
||||||
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
|
promptForSecret func(promptLabel string) (string, error)
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -275,7 +276,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
callbackPath: "/callback",
|
callbackPath: "/callback",
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
logger: logr.Discard(), // discard logs unless a logger is specified
|
logger: logr.Discard(), // discard logs unless a logger is specified
|
||||||
callbacks: make(chan callbackResult),
|
callbacks: make(chan callbackResult, 2),
|
||||||
httpClient: http.DefaultClient,
|
httpClient: http.DefaultClient,
|
||||||
|
|
||||||
// Default implementations of external dependencies (to be mocked in tests).
|
// Default implementations of external dependencies (to be mocked in tests).
|
||||||
@ -520,7 +521,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
|
|||||||
|
|
||||||
password := h.getEnv(defaultPasswordEnvVarName)
|
password := h.getEnv(defaultPasswordEnvVarName)
|
||||||
if password == "" {
|
if password == "" {
|
||||||
password, err = h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
|
password, err = h.promptForSecret(defaultLDAPPasswordPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("error prompting for password: %w", err)
|
return "", "", fmt.Errorf("error prompting for password: %w", err)
|
||||||
}
|
}
|
||||||
@ -576,11 +577,13 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
|||||||
h.logger.V(debugLogLevel).Error(err, "could not open browser")
|
h.logger.V(debugLogLevel).Error(err, "could not open browser")
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(h.ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
|
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
|
||||||
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
|
ctx, cancel := context.WithCancel(h.ctx)
|
||||||
|
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
cleanupPrompt()
|
||||||
|
}()
|
||||||
|
|
||||||
// Wait for either the web callback, a pasted auth code, or a timeout.
|
// Wait for either the web callback, a pasted auth code, or a timeout.
|
||||||
select {
|
select {
|
||||||
@ -594,26 +597,38 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
|
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
|
||||||
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
|
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
|
||||||
|
|
||||||
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
|
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
|
||||||
// since we have no way of reading it.
|
// since we have no way of reading it.
|
||||||
if !h.isTTY(stdin()) {
|
if !h.isTTY(stdin()) {
|
||||||
return
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
|
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
|
||||||
// code because the user isn't going to have any easy way to manually copy it anyway.
|
// code because the user isn't going to have any easy way to manually copy it anyway.
|
||||||
if !h.useFormPost {
|
if !h.useFormPost {
|
||||||
return
|
return func() {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
|
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
|
||||||
// if the parent context is cancelled (when the login succeeds or times out).
|
// if the parent context is cancelled (when the login succeeds or times out).
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
|
defer func() {
|
||||||
|
// Always emit a newline so the kubectl output is visually separated from the login prompts.
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr)
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
|
||||||
|
// newline that simulates the user having pressed "enter".
|
||||||
|
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
|
||||||
|
|
||||||
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
|
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -622,6 +637,7 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
|
|||||||
token, err := h.redeemAuthCode(ctx, code)
|
token, err := h.redeemAuthCode(ctx, code)
|
||||||
h.callbacks <- callbackResult{token: token, err: err}
|
h.callbacks <- callbackResult{token: token, err: err}
|
||||||
}()
|
}()
|
||||||
|
return wg.Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
||||||
@ -633,23 +649,30 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
|||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
type readResult struct {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
text string
|
||||||
defer cancel()
|
err error
|
||||||
|
}
|
||||||
|
readResults := make(chan readResult)
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
readResults <- readResult{text, err}
|
||||||
|
close(readResults)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
// If the context is canceled, return immediately. The ReadString() operation will stay hung in the background
|
||||||
if err != nil {
|
// goroutine indefinitely.
|
||||||
return "", fmt.Errorf("could read input from stdin: %w", err)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case r := <-readResults:
|
||||||
|
return strings.TrimSpace(r.text), r.err
|
||||||
}
|
}
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
return text, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
func promptForSecret(promptLabel string) (string, error) {
|
||||||
if !term.IsTerminal(stdin()) {
|
if !term.IsTerminal(stdin()) {
|
||||||
return "", errors.New("stdin is not connected to a terminal")
|
return "", errors.New("stdin is not connected to a terminal")
|
||||||
}
|
}
|
||||||
@ -657,27 +680,17 @@ func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
|
||||||
|
|
||||||
// term.ReadPassword swallows the newline that was typed by the user, so to
|
|
||||||
// avoid the next line of output from happening on same line as the password
|
|
||||||
// prompt, we need to print a newline.
|
|
||||||
//
|
|
||||||
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
|
|
||||||
// on stderr is formatted correctly.
|
|
||||||
_, _ = fmt.Fprint(os.Stderr, "\n")
|
|
||||||
}()
|
|
||||||
|
|
||||||
password, err := term.ReadPassword(stdin())
|
password, err := term.ReadPassword(stdin())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not read password: %w", err)
|
return "", fmt.Errorf("could not read password: %w", err)
|
||||||
}
|
}
|
||||||
|
// term.ReadPassword swallows the newline that was typed by the user, so to
|
||||||
|
// avoid the next line of output from happening on same line as the password
|
||||||
|
// prompt, we need to print a newline.
|
||||||
|
_, err = fmt.Fprint(os.Stderr, "\n")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not print newline to stderr: %w", err)
|
||||||
|
}
|
||||||
return string(password), err
|
return string(password), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -882,9 +895,9 @@ func (h *handlerState) serve(listener net.Listener) func() {
|
|||||||
}
|
}
|
||||||
go func() { _ = srv.Serve(listener) }()
|
go func() { _ = srv.Serve(listener) }()
|
||||||
return func() {
|
return func() {
|
||||||
// Gracefully shut down the server, allowing up to 5 00ms for
|
// Gracefully shut down the server, allowing up to 100ms for
|
||||||
// clients to receive any in-flight responses.
|
// clients to receive any in-flight responses.
|
||||||
shutdownCtx, cancel := context.WithTimeout(h.ctx, 500*time.Millisecond)
|
shutdownCtx, cancel := context.WithTimeout(h.ctx, 100*time.Millisecond)
|
||||||
_ = srv.Shutdown(shutdownCtx)
|
_ = srv.Shutdown(shutdownCtx)
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
|
@ -251,7 +251,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||||
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
||||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
|
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
|
||||||
|
|
||||||
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
||||||
cacheKey := SessionCacheKey{
|
cacheKey := SessionCacheKey{
|
||||||
@ -541,7 +541,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
||||||
return fmt.Errorf("some browser open error")
|
return fmt.Errorf("some browser open error")
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -567,7 +567,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -825,7 +825,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
|
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -1018,7 +1018,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.Equal(t, "Username: ", promptLabel)
|
require.Equal(t, "Username: ", promptLabel)
|
||||||
return "some-upstream-username", nil
|
return "some-upstream-username", nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
h.promptForSecret = func(promptLabel string) (string, error) {
|
||||||
require.Equal(t, "Password: ", promptLabel)
|
require.Equal(t, "Password: ", promptLabel)
|
||||||
return "some-upstream-password", nil
|
return "some-upstream-password", nil
|
||||||
}
|
}
|
||||||
@ -1125,7 +1125,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
h.promptForSecret = func(promptLabel string) (string, error) {
|
||||||
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
@ -1634,8 +1634,8 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
|
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
|
||||||
return "", fmt.Errorf("some prompt error")
|
return "", fmt.Errorf("some prompt error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -1651,7 +1651,7 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "invalid", nil
|
return "invalid", nil
|
||||||
}
|
}
|
||||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||||
@ -1675,7 +1675,7 @@ func TestHandlePasteCallback(t *testing.T) {
|
|||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
h.isTTY = func(fd int) bool { return true }
|
h.isTTY = func(fd int) bool { return true }
|
||||||
h.useFormPost = true
|
h.useFormPost = true
|
||||||
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
return "valid", nil
|
return "valid", nil
|
||||||
}
|
}
|
||||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||||
|
@ -331,9 +331,9 @@ func TestE2EFullIntegration(t *testing.T) {
|
|||||||
|
|
||||||
// Wait for the subprocess to print the login prompt.
|
// Wait for the subprocess to print the login prompt.
|
||||||
t.Logf("waiting for CLI to output login URL and manual prompt")
|
t.Logf("waiting for CLI to output login URL and manual prompt")
|
||||||
output := readFromFileUntilStringIsSeen(t, ptyFile, "If automatic login fails, paste your authorization code to login manually: ")
|
output := readFromFileUntilStringIsSeen(t, ptyFile, "Optionally, paste your authorization code: ")
|
||||||
require.Contains(t, output, "Log in by visiting this link:")
|
require.Contains(t, output, "Log in by visiting this link:")
|
||||||
require.Contains(t, output, "If automatic login fails, paste your authorization code to login manually: ")
|
require.Contains(t, output, "Optionally, paste your authorization code: ")
|
||||||
|
|
||||||
// Find the line with the login URL.
|
// Find the line with the login URL.
|
||||||
var loginURL string
|
var loginURL string
|
||||||
|
Loading…
Reference in New Issue
Block a user