Add ctx params to promptForValue() and promptForSecret().

This allows the prompts to be cancelled, which we need to be able to do in the case where we prompt for a manually-pasted auth code but the automatic callback succeeds.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2021-06-30 15:06:37 -05:00
parent 9fba8d2203
commit 95ee9f0b00
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 40 additions and 20 deletions

View File

@ -17,6 +17,7 @@ import (
"os"
"sort"
"strings"
"syscall"
"time"
"github.com/coreos/go-oidc/v3/oidc"
@ -99,8 +100,8 @@ type handlerState struct {
openURL func(string) error
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
callbacks chan callbackResult
}
@ -377,11 +378,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Ask the user for their username and password.
username, err := h.promptForValue(defaultLDAPUsernamePrompt)
username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
if err != nil {
return nil, fmt.Errorf("error prompting for username: %w", err)
}
password, err := h.promptForSecret(defaultLDAPPasswordPrompt)
password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
if err != nil {
return nil, fmt.Errorf("error prompting for password: %w", err)
}
@ -517,7 +518,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}
func promptForValue(promptLabel string) (string, error) {
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal")
}
@ -525,6 +526,15 @@ func promptForValue(promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
}()
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
if err != nil {
return "", fmt.Errorf("could read input from stdin: %w", err)
@ -533,7 +543,7 @@ func promptForValue(promptLabel string) (string, error) {
return text, nil
}
func promptForSecret(promptLabel string) (string, error) {
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal")
}
@ -541,16 +551,26 @@ func promptForSecret(promptLabel string) (string, error) {
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
password, err := term.ReadPassword(0)
if err != nil {
return "", fmt.Errorf("could not read password: %w", err)
}
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = os.Stdin.SetReadDeadline(time.Now())
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
//
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
// on stderr is formatted correctly.
_, _ = fmt.Fprint(os.Stderr, "\n")
}()
password, err := term.ReadPassword(syscall.Stdin)
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
return "", fmt.Errorf("could not read password: %w", err)
}
return string(password), err
}

View File

@ -248,8 +248,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
@ -751,7 +751,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error")
}
@ -768,7 +768,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") }
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
@ -954,11 +954,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}