diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index cd03e4ee..836f2790 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -17,6 +17,7 @@ import ( "os" "sort" "strings" + "syscall" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -99,8 +100,8 @@ type handlerState struct { openURL func(string) error getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) - promptForValue func(promptLabel string) (string, error) - promptForSecret func(promptLabel string) (string, error) + promptForValue func(ctx context.Context, promptLabel string) (string, error) + promptForSecret func(ctx context.Context, promptLabel string) (string, error) callbacks chan callbackResult } @@ -377,11 +378,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) { // and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error. func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { // Ask the user for their username and password. - username, err := h.promptForValue(defaultLDAPUsernamePrompt) + username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt) if err != nil { return nil, fmt.Errorf("error prompting for username: %w", err) } - password, err := h.promptForSecret(defaultLDAPPasswordPrompt) + password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt) if err != nil { return nil, fmt.Errorf("error prompting for password: %w", err) } @@ -517,7 +518,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } -func promptForValue(promptLabel string) (string, error) { +func promptForValue(ctx context.Context, promptLabel string) (string, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return "", errors.New("stdin is not connected to a terminal") } @@ -525,6 +526,15 @@ func promptForValue(promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } + + // If the context is canceled, set the read deadline on stdin so the read immediately finishes. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + _ = os.Stdin.SetReadDeadline(time.Now()) + }() + text, err := bufio.NewReader(os.Stdin).ReadString('\n') if err != nil { return "", fmt.Errorf("could read input from stdin: %w", err) @@ -533,7 +543,7 @@ func promptForValue(promptLabel string) (string, error) { return text, nil } -func promptForSecret(promptLabel string) (string, error) { +func promptForSecret(ctx context.Context, promptLabel string) (string, error) { if !term.IsTerminal(int(os.Stdin.Fd())) { return "", errors.New("stdin is not connected to a terminal") } @@ -541,17 +551,27 @@ func promptForSecret(promptLabel string) (string, error) { if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } - password, err := term.ReadPassword(0) + + // If the context is canceled, set the read deadline on stdin so the read immediately finishes. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + <-ctx.Done() + _ = os.Stdin.SetReadDeadline(time.Now()) + + // term.ReadPassword swallows the newline that was typed by the user, so to + // avoid the next line of output from happening on same line as the password + // prompt, we need to print a newline. + // + // Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next + // on stderr is formatted correctly. + _, _ = fmt.Fprint(os.Stderr, "\n") + }() + + password, err := term.ReadPassword(syscall.Stdin) if err != nil { return "", fmt.Errorf("could not read password: %w", err) } - // term.ReadPassword swallows the newline that was typed by the user, so to - // avoid the next line of output from happening on same line as the password - // prompt, we need to print a newline. - _, err = fmt.Fprint(os.Stderr, "\n") - if err != nil { - return "", fmt.Errorf("could not print newline to stderr: %w", err) - } return string(password), err } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 0e3ce673..907935ef 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -248,8 +248,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generateState = func() (state.State, error) { return "test-state", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } - h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil } - h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil } + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } + h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ @@ -751,7 +751,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForValue = func(promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Username: ", promptLabel) return "", errors.New("some prompt error") } @@ -768,7 +768,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") } + h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") } return nil } }, @@ -954,11 +954,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo h.generateState = func() (state.State, error) { return "test-state", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } - h.promptForValue = func(promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Username: ", promptLabel) return "some-upstream-username", nil } - h.promptForSecret = func(promptLabel string) (string, error) { + h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) { require.Equal(t, "Password: ", promptLabel) return "some-upstream-password", nil }