Add ctx params to promptForValue() and promptForSecret().
This allows the prompts to be cancelled, which we need to be able to do in the case where we prompt for a manually-pasted auth code but the automatic callback succeeds. Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
9fba8d2203
commit
95ee9f0b00
@ -17,6 +17,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
@ -99,8 +100,8 @@ type handlerState struct {
|
|||||||
openURL func(string) error
|
openURL func(string) error
|
||||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||||
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
||||||
promptForValue func(promptLabel string) (string, error)
|
promptForValue func(ctx context.Context, promptLabel string) (string, error)
|
||||||
promptForSecret func(promptLabel string) (string, error)
|
promptForSecret func(ctx context.Context, promptLabel string) (string, error)
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -377,11 +378,11 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
|
|||||||
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
|
// and parse the authcode from the response. Exchange the authcode for tokens. Return the tokens or an error.
|
||||||
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
|
func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
|
||||||
// Ask the user for their username and password.
|
// Ask the user for their username and password.
|
||||||
username, err := h.promptForValue(defaultLDAPUsernamePrompt)
|
username, err := h.promptForValue(h.ctx, defaultLDAPUsernamePrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error prompting for username: %w", err)
|
return nil, fmt.Errorf("error prompting for username: %w", err)
|
||||||
}
|
}
|
||||||
password, err := h.promptForSecret(defaultLDAPPasswordPrompt)
|
password, err := h.promptForSecret(h.ctx, defaultLDAPPasswordPrompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error prompting for password: %w", err)
|
return nil, fmt.Errorf("error prompting for password: %w", err)
|
||||||
}
|
}
|
||||||
@ -517,7 +518,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForValue(promptLabel string) (string, error) {
|
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
|
||||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
return "", errors.New("stdin is not connected to a terminal")
|
return "", errors.New("stdin is not connected to a terminal")
|
||||||
}
|
}
|
||||||
@ -525,6 +526,15 @@ func promptForValue(promptLabel string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
_ = os.Stdin.SetReadDeadline(time.Now())
|
||||||
|
}()
|
||||||
|
|
||||||
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
text, err := bufio.NewReader(os.Stdin).ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could read input from stdin: %w", err)
|
return "", fmt.Errorf("could read input from stdin: %w", err)
|
||||||
@ -533,7 +543,7 @@ func promptForValue(promptLabel string) (string, error) {
|
|||||||
return text, nil
|
return text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func promptForSecret(promptLabel string) (string, error) {
|
func promptForSecret(ctx context.Context, promptLabel string) (string, error) {
|
||||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||||
return "", errors.New("stdin is not connected to a terminal")
|
return "", errors.New("stdin is not connected to a terminal")
|
||||||
}
|
}
|
||||||
@ -541,16 +551,26 @@ func promptForSecret(promptLabel string) (string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
|
||||||
}
|
}
|
||||||
password, err := term.ReadPassword(0)
|
|
||||||
if err != nil {
|
// If the context is canceled, set the read deadline on stdin so the read immediately finishes.
|
||||||
return "", fmt.Errorf("could not read password: %w", err)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
}
|
defer cancel()
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
_ = os.Stdin.SetReadDeadline(time.Now())
|
||||||
|
|
||||||
// term.ReadPassword swallows the newline that was typed by the user, so to
|
// term.ReadPassword swallows the newline that was typed by the user, so to
|
||||||
// avoid the next line of output from happening on same line as the password
|
// avoid the next line of output from happening on same line as the password
|
||||||
// prompt, we need to print a newline.
|
// prompt, we need to print a newline.
|
||||||
_, err = fmt.Fprint(os.Stderr, "\n")
|
//
|
||||||
|
// Even if the read was cancelled prematurely, we still want to echo a newline so whatever comes next
|
||||||
|
// on stderr is formatted correctly.
|
||||||
|
_, _ = fmt.Fprint(os.Stderr, "\n")
|
||||||
|
}()
|
||||||
|
|
||||||
|
password, err := term.ReadPassword(syscall.Stdin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("could not print newline to stderr: %w", err)
|
return "", fmt.Errorf("could not read password: %w", err)
|
||||||
}
|
}
|
||||||
return string(password), err
|
return string(password), err
|
||||||
}
|
}
|
||||||
|
@ -248,8 +248,8 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
h.generateState = func() (state.State, error) { return "test-state", nil }
|
h.generateState = func() (state.State, error) { return "test-state", nil }
|
||||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||||
h.promptForValue = func(promptLabel string) (string, error) { return "some-upstream-username", nil }
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
|
||||||
h.promptForSecret = func(promptLabel string) (string, error) { return "some-upstream-password", nil }
|
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "some-upstream-password", nil }
|
||||||
|
|
||||||
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
cache := &mockSessionCache{t: t, getReturnsToken: nil}
|
||||||
cacheKey := SessionCacheKey{
|
cacheKey := SessionCacheKey{
|
||||||
@ -751,7 +751,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||||
h.promptForValue = func(promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
require.Equal(t, "Username: ", promptLabel)
|
require.Equal(t, "Username: ", promptLabel)
|
||||||
return "", errors.New("some prompt error")
|
return "", errors.New("some prompt error")
|
||||||
}
|
}
|
||||||
@ -768,7 +768,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
_ = defaultLDAPTestOpts(t, h, nil, nil)
|
||||||
h.promptForSecret = func(promptLabel string) (string, error) { return "", errors.New("some prompt error") }
|
h.promptForSecret = func(_ context.Context, _ string) (string, error) { return "", errors.New("some prompt error") }
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -954,11 +954,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
|
|||||||
h.generateState = func() (state.State, error) { return "test-state", nil }
|
h.generateState = func() (state.State, error) { return "test-state", nil }
|
||||||
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
|
||||||
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
|
||||||
h.promptForValue = func(promptLabel string) (string, error) {
|
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
require.Equal(t, "Username: ", promptLabel)
|
require.Equal(t, "Username: ", promptLabel)
|
||||||
return "some-upstream-username", nil
|
return "some-upstream-username", nil
|
||||||
}
|
}
|
||||||
h.promptForSecret = func(promptLabel string) (string, error) {
|
h.promptForSecret = func(_ context.Context, promptLabel string) (string, error) {
|
||||||
require.Equal(t, "Password: ", promptLabel)
|
require.Equal(t, "Password: ", promptLabel)
|
||||||
return "some-upstream-password", nil
|
return "some-upstream-password", nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user