backfill unit tests for expected stderr output in login_test.go

Co-authored-by: Joshua Casey <joshuatcasey@gmail.com>
This commit is contained in:
Ryan Richard 2023-10-09 13:44:42 -07:00
parent 6ee1e35329
commit 3a21c9a35b
2 changed files with 175 additions and 80 deletions

View File

@ -78,6 +78,7 @@ type handlerState struct {
clientID string clientID string
scopes []string scopes []string
cache SessionCache cache SessionCache
out io.Writer
upstreamIdentityProviderName string upstreamIdentityProviderName string
upstreamIdentityProviderType string upstreamIdentityProviderType string
@ -109,8 +110,8 @@ type handlerState struct {
isTTY func(int) bool isTTY func(int) bool
getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error) validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string, out io.Writer) (string, error)
promptForSecret func(promptLabel string) (string, error) promptForSecret func(promptLabel string, out io.Writer) (string, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -292,6 +293,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
}, },
promptForValue: promptForValue, promptForValue: promptForValue,
promptForSecret: promptForSecret, promptForSecret: promptForSecret,
out: os.Stderr,
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -513,7 +515,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
username := h.getEnv(defaultUsernameEnvVarName) username := h.getEnv(defaultUsernameEnvVarName)
if username == "" { if username == "" {
username, err = h.promptForValue(h.ctx, usernamePrompt) username, err = h.promptForValue(h.ctx, usernamePrompt, h.out)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error prompting for username: %w", err) return "", "", fmt.Errorf("error prompting for username: %w", err)
} }
@ -523,7 +525,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName) password := h.getEnv(defaultPasswordEnvVarName)
if password == "" { if password == "" {
password, err = h.promptForSecret(passwordPrompt) password, err = h.promptForSecret(passwordPrompt, h.out)
if err != nil { if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err) return "", "", fmt.Errorf("error prompting for password: %w", err)
} }
@ -581,7 +583,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
ctx, cancel := context.WithCancel(h.ctx) ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr) cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL)
defer func() { defer func() {
cancel() cancel()
cleanupPrompt() cleanupPrompt()
@ -599,8 +601,8 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
} }
} }
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() { func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) _, _ = fmt.Fprintf(h.out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste, // If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it. // since we have no way of reading it.
@ -621,15 +623,15 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
go func() { go func() {
defer func() { defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts. // Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr) _, _ = fmt.Fprintln(h.out)
wg.Done() wg.Done()
}() }()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ") code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ", h.out)
if err != nil { if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing // Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter". // newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n") _, _ = fmt.Fprint(h.out, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return return
@ -642,11 +644,11 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
return wg.Wait return wg.Wait
} }
func promptForValue(ctx context.Context, promptLabel string) (string, error) { func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) { if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
_, err := fmt.Fprint(os.Stderr, promptLabel) _, err := fmt.Fprint(out, promptLabel)
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
@ -674,11 +676,11 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
} }
} }
func promptForSecret(promptLabel string) (string, error) { func promptForSecret(promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) { if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
} }
_, err := fmt.Fprint(os.Stderr, promptLabel) _, err := fmt.Fprint(out, promptLabel)
if err != nil { if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err) return "", fmt.Errorf("could not print prompt to stderr: %w", err)
} }
@ -689,7 +691,7 @@ func promptForSecret(promptLabel string) (string, error) {
// term.ReadPassword swallows the newline that was typed by the user, so to // term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password // avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline. // prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n") _, err = fmt.Fprint(out, "\n")
if err != nil { if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err) return "", fmt.Errorf("could not print newline to stderr: %w", err)
} }

View File

@ -15,6 +15,8 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"regexp"
"strings" "strings"
"syscall" "syscall"
"testing" "testing"
@ -77,6 +79,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
time1Unix := int64(2075807775) time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
const testCodeChallenge = "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
testToken := oidctypes.Token{ testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
@ -316,8 +323,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil } h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil } return "some-upstream-username", nil
}
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil} cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -352,13 +361,14 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
} }
tests := []struct { tests := []struct {
name string name string
opt func(t *testing.T) Option opt func(t *testing.T) Option
issuer string issuer string
clientID string clientID string
wantErr string wantErr string
wantToken *oidctypes.Token wantToken *oidctypes.Token
wantLogs []string wantLogs []string
wantStdErr string
}{ }{
{ {
name: "option error", name: "option error",
@ -699,6 +709,9 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
name: "listening disabled and manual prompt fails", name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
require.NoError(t, WithSkipListen()(h)) require.NoError(t, WithSkipListen()(h))
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
@ -709,7 +722,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error") return fmt.Errorf("some browser open error")
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -720,12 +733,24 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`, `"msg"="could not open browser" "error"="some browser open error"`,
}, },
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "listen success and manual prompt succeeds", name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
@ -736,7 +761,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode")) require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil return nil
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
return nil return nil
@ -747,12 +772,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`, `"msg"="could not open callback listener" "error"="some listen error"`,
}, },
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "timeout waiting for callback", name: "timeout waiting for callback",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
ctx, cancel := context.WithCancel(h.ctx) ctx, cancel := context.WithCancel(h.ctx)
@ -767,12 +805,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "timed out waiting for token callback: context canceled", wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "timed out waiting for token callback: context canceled",
}, },
{ {
name: "callback returns error", name: "callback returns error",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
h.openURL = func(_ string) error { h.openURL = func(_ string) error {
go func() { go func() {
@ -785,7 +836,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error handling callback: some callback error", wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "error handling callback: some callback error",
}, },
{ {
name: "callback returns success", name: "callback returns success",
@ -823,10 +884,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -847,8 +905,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil return nil
} }
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -887,10 +955,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"response_mode": []string{"form_post"}, "response_mode": []string{"form_post"},
@ -912,8 +977,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil return nil
} }
}, },
issuer: formPostSuccessServer.URL, issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -954,10 +1029,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri") actualParams.Del("redirect_uri")
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -980,8 +1052,19 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil return nil
} }
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=oidc"+
"&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
@ -990,7 +1073,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error") return "", errors.New("some prompt error")
} }
@ -1007,7 +1090,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil) _ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") } h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "", errors.New("some prompt error") }
return nil return nil
} }
}, },
@ -1069,7 +1152,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() + wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() +
`/authorize?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": some error fetching authorize endpoint`, `/authorize?access_type=offline&client_id=test-client-id&code_challenge=` + testCodeChallenge +
`&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&` +
`pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code` +
`&scope=test-scope&state=test-state": some error fetching authorize endpoint`,
}, },
{ {
name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect", name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect",
@ -1198,11 +1284,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.getEnv = func(_ string) string { h.getEnv = func(_ string) string {
return "" // asking for any env var returns empty as if it were unset return "" // asking for any env var returns empty as if it were unset
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel) require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil return "some-upstream-username", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Password: ", promptLabel) require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil return "some-upstream-password", nil
} }
@ -1242,10 +1328,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1306,22 +1389,21 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset return "" // all other env vars are treated as if they are unset
} }
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
cache := &mockSessionCache{t: t, getReturnsToken: nil} cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
Issuer: successServer.URL, Issuer: successServer.URL,
ClientID: "test-client-id", ClientID: "test-client-id",
Scopes: []string{"test-scope"}, Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback", RedirectURI: "http://localhost:0/callback",
UpstreamProviderName: "some-upstream-name",
} }
t.Cleanup(func() { t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
@ -1330,7 +1412,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}) })
require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h)) require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
discoveryRequestWasMade := false discoveryRequestWasMade := false
authorizeRequestWasMade := false authorizeRequestWasMade := false
@ -1350,10 +1431,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1362,8 +1440,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"access_type": []string{"offline"}, "access_type": []string{"offline"},
"client_id": []string{"test-client-id"}, "client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"}, "redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query()) }, req.URL.Query())
return &http.Response{ return &http.Response{
StatusCode: http.StatusFound, StatusCode: http.StatusFound,
@ -1418,11 +1494,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset return "" // all other env vars are treated as if they are unset
} }
} }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
h.promptForSecret = func(promptLabel string) (string, error) { h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil return "", nil
} }
@ -1462,10 +1538,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{ require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: "code_challenge": []string{testCodeChallenge},
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"test-scope"}, "scope": []string{"test-scope"},
@ -1898,6 +1971,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements
klog.SetLogger(testLogger.Logger) klog.SetLogger(testLogger.Logger)
buffer := bytes.Buffer{}
tok, err := Login(tt.issuer, tt.clientID, tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()), WithContext(context.Background()),
WithListenPort(0), WithListenPort(0),
@ -1905,8 +1979,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
WithSkipBrowserOpen(), WithSkipBrowserOpen(),
tt.opt(t), tt.opt(t),
WithLogger(testLogger.Logger), WithLogger(testLogger.Logger),
withOutWriter(t, &buffer),
) )
testLogger.Expect(tt.wantLogs) testLogger.Expect(tt.wantLogs)
if tt.wantStdErr == "" {
require.Empty(t, buffer.String())
} else {
require.Regexp(t, tt.wantStdErr, buffer.String())
}
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr) require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok) require.Nil(t, tok)
@ -1940,6 +2023,15 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
} }
} }
func withOutWriter(t *testing.T, out io.Writer) Option {
return func(h *handlerState) error {
// Ensure that the proper default value has been set in the handlerState prior to overriding it for tests.
require.Equal(t, os.Stderr, h.out)
h.out = out
return nil
}
}
func TestHandlePasteCallback(t *testing.T) { func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback" const testRedirectURI = "http://127.0.0.1:12324/callback"
@ -1977,7 +2069,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel) assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error") return "", fmt.Errorf("some prompt error")
} }
@ -1994,7 +2086,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "invalid", nil return "invalid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2018,7 +2110,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true } h.isTTY = func(fd int) bool { return true }
h.useFormPost = true h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "valid", nil return "valid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2042,11 +2134,13 @@ func TestHandlePasteCallback(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
buf := &bytes.Buffer{}
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
out: buf,
} }
if tt.opt != nil { if tt.opt != nil {
require.NoError(t, tt.opt(t)(h)) require.NoError(t, tt.opt(t)(h))
@ -2054,8 +2148,7 @@ func TestHandlePasteCallback(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
var buf bytes.Buffer h.promptForWebLogin(ctx, "https://test-authorize-url/")
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
require.Equal(t, require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n", "Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(), buf.String(),