Merge pull request #1691 from vmware-tanzu/jtc/display-idp-name-when-prompting-for-login-181927293

Display IDP name when prompting for username and password
This commit is contained in:
Ryan Richard 2023-10-09 21:12:49 -07:00 committed by GitHub
commit 521dec2e04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 222 additions and 112 deletions

View File

@ -56,8 +56,8 @@ const (
// we set this to be relatively long. // we set this to be relatively long.
overallTimeout = 90 * time.Minute overallTimeout = 90 * time.Minute
defaultLDAPUsernamePrompt = "Username: " usernamePrompt = "Username: "
defaultLDAPPasswordPrompt = "Password: " passwordPrompt = "Password: "
// For CLI-based auth, such as with LDAP upstream identity providers, the user may use these environment variables // For CLI-based auth, such as with LDAP upstream identity providers, the user may use these environment variables
// to avoid getting interactively prompted for username and password. // to avoid getting interactively prompted for username and password.
@ -78,6 +78,7 @@ type handlerState struct {
clientID string clientID string
scopes []string scopes []string
cache SessionCache cache SessionCache
out io.Writer
upstreamIdentityProviderName string upstreamIdentityProviderName string
upstreamIdentityProviderType string upstreamIdentityProviderType string
@ -109,8 +110,8 @@ type handlerState struct {
isTTY func(int) bool isTTY func(int) bool
getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error) validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string, out io.Writer) (string, error)
promptForSecret func(promptLabel string) (string, error) promptForSecret func(promptLabel string, out io.Writer) (string, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -292,6 +293,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
}, },
promptForValue: promptForValue, promptForValue: promptForValue,
promptForSecret: promptForSecret, promptForSecret: promptForSecret,
out: os.Stderr,
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -511,9 +513,13 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
func (h *handlerState) getUsernameAndPassword() (string, string, error) { func (h *handlerState) getUsernameAndPassword() (string, string, error) {
var err error var err error
if h.upstreamIdentityProviderName != "" {
_, _ = fmt.Fprintf(h.out, "\nLog in to %s\n\n", h.upstreamIdentityProviderName)
}
username := h.getEnv(defaultUsernameEnvVarName) username := h.getEnv(defaultUsernameEnvVarName)
if username == "" { if username == "" {
username, err = h.promptForValue(h.ctx, defaultLDAPUsernamePrompt) username, err = h.promptForValue(h.ctx, usernamePrompt, h.out)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error prompting for username: %w", err) return "", "", fmt.Errorf("error prompting for username: %w", err)
} }
@ -523,7 +529,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName) password := h.getEnv(defaultPasswordEnvVarName)
if password == "" { if password == "" {
password, err = h.promptForSecret(defaultLDAPPasswordPrompt) password, err = h.promptForSecret(passwordPrompt, h.out)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err) return "", "", fmt.Errorf("error prompting for password: %w", err)
} }
@ -581,7 +587,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
ctx, cancel := context.WithCancel(h.ctx) ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr) cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL)
defer func() { defer func() {
cancel() cancel()
cleanupPrompt() cleanupPrompt()
@ -599,8 +605,8 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
} }
} }
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() { func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) _, _ = fmt.Fprintf(h.out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste, // If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it. // since we have no way of reading it.
@ -621,15 +627,15 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
go func() { go func() {
defer func() { defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts. // Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr) _, _ = fmt.Fprintln(h.out)
wg.Done() wg.Done()
}() }()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ") code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ", h.out)
if err != nil { if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing // Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter". // newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n") _, _ = fmt.Fprint(h.out, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return return
@ -642,11 +648,11 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
return wg.Wait return wg.Wait
} }
func promptForValue(ctx context.Context, promptLabel string) (string, error) { func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) { if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
_, err := fmt.Fprint(os.Stderr, promptLabel) _, err := fmt.Fprint(out, promptLabel)
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
@ -674,11 +680,11 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
} }
} }
func promptForSecret(promptLabel string) (string, error) { func promptForSecret(promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) { if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
_, err := fmt.Fprint(os.Stderr, promptLabel) _, err := fmt.Fprint(out, promptLabel)
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
@ -689,7 +695,7 @@ func promptForSecret(promptLabel string) (string, error) {
// term.ReadPassword swallows the newline that was typed by the user, so to // term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password // avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline. // prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n") _, err = fmt.Fprint(out, "\n")
if err != nil { if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err) return "", fmt.Errorf("could not print newline to stderr: %w", err)
} }

View File

@ -15,6 +15,8 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"regexp"
"strings" "strings"
"syscall" "syscall"
"testing" "testing"
@ -77,6 +79,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
time1Unix := int64(2075807775) time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
const testCodeChallenge = "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
testToken := oidctypes.Token{ testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
@ -316,8 +323,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil } h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil } return "some-upstream-username", nil
}
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil} cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -359,6 +368,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
wantErr string wantErr string
wantToken *oidctypes.Token wantToken *oidctypes.Token
wantLogs []string wantLogs []string
wantStdErr string
}{ }{
{ {
name: "option error", name: "option error",
@ -699,6 +709,9 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
name: "listening disabled and manual prompt fails", name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
require.NoError(t, WithSkipListen()(h)) require.NoError(t, WithSkipListen()(h))
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
@ -709,7 +722,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error") return fmt.Errorf("some browser open error")
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -720,12 +733,24 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`, `"msg"="could not open browser" "error"="some browser open error"`,
}, },
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "listen success and manual prompt succeeds", name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
@ -736,7 +761,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil return nil
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -747,12 +772,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`, `"msg"="could not open callback listener" "error"="some listen error"`,
}, },
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "timeout waiting for callback", name: "timeout waiting for callback",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
ctx, cancel := context.WithCancel(h.ctx) ctx, cancel := context.WithCancel(h.ctx)
@ -767,12 +805,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "timed out waiting for token callback: context canceled", wantErr: "timed out waiting for token callback: context canceled",
}, },
{ {
name: "callback returns error", name: "callback returns error",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
h.openURL = func(_ string) error { h.openURL = func(_ string) error {
go func() { go func() {
@ -785,6 +836,16 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "error handling callback: some callback error", wantErr: "error handling callback: some callback error",
}, },
{ {
@ -823,10 +884,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -849,6 +907,16 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -887,10 +955,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"response_mode": []string{"form_post"}, "response_mode": []string{"form_post"},
@ -914,6 +979,16 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: formPostSuccessServer.URL, issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -954,10 +1029,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -982,6 +1054,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=oidc"+
"&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -990,7 +1073,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error") return "", errors.New("some prompt error")
} }
@ -999,6 +1082,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: "error prompting for username: some prompt error", wantErr: "error prompting for username: some prompt error",
}, },
{ {
@ -1007,12 +1091,13 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") } h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "", errors.New("some prompt error") }
return nil return nil
} }
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: "error prompting for password: some prompt error", wantErr: "error prompting for password: some prompt error",
}, },
{ {
@ -1068,8 +1153,12 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() + wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() +
`/authorize?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": some error fetching authorize endpoint`, `/authorize?access_type=offline&client_id=test-client-id&code_challenge=` + testCodeChallenge +
`&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&` +
`pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code` +
`&scope=test-scope&state=test-state": some error fetching authorize endpoint`,
}, },
{ {
name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect", name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect",
@ -1081,6 +1170,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `error getting authorization: expected to be redirected, but response status was 502 Bad Gateway`, wantErr: `error getting authorization: expected to be redirected, but response status was 502 Bad Gateway`,
}, },
{ {
@ -1098,6 +1188,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `login failed with code "access_denied": optional-error-description`, wantErr: `login failed with code "access_denied": optional-error-description`,
}, },
{ {
@ -1115,6 +1206,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `error getting authorization: redirected to the wrong location: http://other-server.example.com/callback?code=foo&state=test-state`, wantErr: `error getting authorization: redirected to the wrong location: http://other-server.example.com/callback?code=foo&state=test-state`,
}, },
{ {
@ -1132,6 +1224,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `login failed with code "access_denied"`, wantErr: `login failed with code "access_denied"`,
}, },
{ {
@ -1147,6 +1240,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: `missing or invalid state parameter in authorization response: http://127.0.0.1:0/callback?code=foo&state=wrong-state`, wantErr: `missing or invalid state parameter in authorization response: http://127.0.0.1:0/callback?code=foo&state=wrong-state`,
}, },
{ {
@ -1174,6 +1268,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantErr: "could not complete authorization code exchange: some authcode exchange or token validation error", wantErr: "could not complete authorization code exchange: some authcode exchange or token validation error",
}, },
{ {
@ -1198,11 +1293,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.getEnv = func(_ string) string { h.getEnv = func(_ string) string {
return "" // asking for any env var returns empty as if it were unset return "" // asking for any env var returns empty as if it were unset
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil return "some-upstream-username", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Password: ", promptLabel) require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil return "some-upstream-password", nil
} }
@ -1242,10 +1337,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1275,6 +1367,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -1306,11 +1399,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset return "" // all other env vars are treated as if they are unset
} }
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
@ -1321,7 +1414,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
ClientID: "test-client-id", ClientID: "test-client-id",
Scopes: []string{"test-scope"}, Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
UpstreamProviderName: "some-upstream-name",
} }
t.Cleanup(func() { t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
@ -1330,7 +1422,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}) })
require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h)) require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
discoveryRequestWasMade := false discoveryRequestWasMade := false
authorizeRequestWasMade := false authorizeRequestWasMade := false
@ -1350,10 +1441,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1362,8 +1450,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"access_type": []string{"offline"}, "access_type": []string{"offline"},
"client_id": []string{"test-client-id"}, "client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"}, "redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query()) }, req.URL.Query())
return &http.Response{ return &http.Response{
StatusCode: http.StatusFound, StatusCode: http.StatusFound,
@ -1418,11 +1504,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset return "" // all other env vars are treated as if they are unset
} }
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
@ -1462,10 +1548,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1499,6 +1582,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"\"level\"=4 \"msg\"=\"Pinniped: Read username from environment variable\" \"name\"=\"PINNIPED_USERNAME\"", "\"level\"=4 \"msg\"=\"Pinniped: Read username from environment variable\" \"name\"=\"PINNIPED_USERNAME\"",
"\"level\"=4 \"msg\"=\"Pinniped: Read password from environment variable\" \"name\"=\"PINNIPED_PASSWORD\"", "\"level\"=4 \"msg\"=\"Pinniped: Read password from environment variable\" \"name\"=\"PINNIPED_PASSWORD\"",
}, },
wantStdErr: "^\nLog in to some-upstream-name\n\n$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -1898,6 +1982,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements
klog.SetLogger(testLogger.Logger) klog.SetLogger(testLogger.Logger)
buffer := bytes.Buffer{}
tok, err := Login(tt.issuer, tt.clientID, tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()), WithContext(context.Background()),
WithListenPort(0), WithListenPort(0),
@ -1905,8 +1990,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
WithSkipBrowserOpen(), WithSkipBrowserOpen(),
tt.opt(t), tt.opt(t),
WithLogger(testLogger.Logger), WithLogger(testLogger.Logger),
withOutWriter(t, &buffer),
) )
testLogger.Expect(tt.wantLogs) testLogger.Expect(tt.wantLogs)
if tt.wantStdErr == "" {
require.Empty(t, buffer.String())
} else {
require.Regexp(t, tt.wantStdErr, buffer.String())
}
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr) require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok) require.Nil(t, tok)
@ -1940,6 +2034,15 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
} }
} }
func withOutWriter(t *testing.T, out io.Writer) Option {
return func(h *handlerState) error {
// Ensure that the proper default value has been set in the handlerState prior to overriding it for tests.
require.Equal(t, os.Stderr, h.out)
h.out = out
return nil
}
}
func TestHandlePasteCallback(t *testing.T) { func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback" const testRedirectURI = "http://127.0.0.1:12324/callback"
@ -1977,7 +2080,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel) assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
@ -1994,7 +2097,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "invalid", nil return "invalid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2018,7 +2121,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "valid", nil return "valid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2042,11 +2145,13 @@ func TestHandlePasteCallback(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
buf := &bytes.Buffer{}
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
out: buf,
} }
if tt.opt != nil { if tt.opt != nil {
require.NoError(t, tt.opt(t)(h)) require.NoError(t, tt.opt(t)(h))
@ -2054,8 +2159,7 @@ func TestHandlePasteCallback(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
var buf bytes.Buffer h.promptForWebLogin(ctx, "https://test-authorize-url/")
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
require.Equal(t, require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n", "Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(), buf.String(),