Add ctx params to promptForValue() and promptForSecret().
This allows the prompts to be cancelled, which we need to be able to do in the case where we prompt for a manually-pasted auth code but the automatic callback succeeds. Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
9fba8d2203
commit
95ee9f0b00
@ -17,6 +17,7 @@ import (
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
@ -99,8 +100,8 @@ type handlerState struct {
|
||||
openURL func(string) error
|
||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
||||
promptForValue func(promptLabel string) (string, error)
|
||||
promptForSecret func(promptLabel string) (string, error)
|
||||
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
||||
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
|
||||
|
||||
callbacks chan callbackResult
|
||||
}
|
||||
@ -377,11 +378,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
|
||||
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
|
||||
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
|
||||
// Ask the user for their username and password.
|
||||
username, err := h.promptForValue(defaultLDAPUsernamePrompt)
|
||||
username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error prompting for username: %w", err)
|
||||
}
|
||||
password, err := h.promptForSecret(defaultLDAPPasswordPrompt)
|
||||
password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error prompting for password: %w", err)
|
||||
}
|
||||
@ -517,7 +518,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
||||
}
|
||||
}
|
||||
|
||||
func promptForValue(promptLabel string) (string, error) {
|
||||
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
return "", errors.New("stdin is not connected to a terminal")
|
||||
}
|
||||
@ -525,6 +526,15 @@ func promptForValue(promptLabel string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||
}
|
||||
|
||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
||||
}()
|
||||
|
||||
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could read input from stdin: %w", err)
|
||||
@ -533,7 +543,7 @@ func promptForValue(promptLabel string) (string, error) {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func promptForSecret(promptLabel string) (string, error) {
|
||||
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
return "", errors.New("stdin is not connected to a terminal")
|
||||
}
|
||||
@ -541,16 +551,26 @@ func promptForSecret(promptLabel string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||
}
|
||||
password, err := term.ReadPassword(0)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not read password: %w", err)
|
||||
}
|
||||
|
||||
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = os.Stdin.SetReadDeadline(time.Now())
|
||||
|
||||
// term.ReadPassword swallows the newline that was typed by the user, so to
|
||||
// avoid the next line of output from happening on same line as the password
|
||||
// prompt, we need to print a newline.
|
||||
_, err = fmt.Fprint(os.Stderr, "\n")
|
||||
//
|
||||
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
|
||||
// on stderr is formatted correctly.
|
||||
_, _ = fmt.Fprint(os.Stderr, "\n")
|
||||
}()
|
||||
|
||||
password, err := term.ReadPassword(syscall.Stdin)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not print newline to stderr: %w", err)
|
||||
return "", fmt.Errorf("could not read password: %w", err)
|
||||
}
|
||||
return string(password), err
|
||||
}
|
||||
|
@ -248,8 +248,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
h.generateState = func() (state.State, error) { return "test-state", nil }
|
||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||
h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil }
|
||||
h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil }
|
||||
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
|
||||
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
||||
cacheKey := SessionCacheKey{
|
||||
@ -751,7 +751,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||
h.promptForValue = func(promptLabel string) (string, error) {
|
||||
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||
require.Equal(t, "Username: ", promptLabel)
|
||||
return "", errors.New("some prompt error")
|
||||
}
|
||||
@ -768,7 +768,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||
h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") }
|
||||
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
|
||||
return nil
|
||||
}
|
||||
},
|
||||
@ -954,11 +954,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
||||
h.generateState = func() (state.State, error) { return "test-state", nil }
|
||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||
h.promptForValue = func(promptLabel string) (string, error) {
|
||||
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||
require.Equal(t, "Username: ", promptLabel)
|
||||
return "some-upstream-username", nil
|
||||
}
|
||||
h.promptForSecret = func(promptLabel string) (string, error) {
|
||||
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
||||
require.Equal(t, "Password: ", promptLabel)
|
||||
return "some-upstream-password", nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user