From 3a21c9a35b2aab5b7c0a84845f3431df98d37825 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Mon, 9 Oct 2023 13:44:42 -0700 Subject: [PATCH] backfill unit tests for expected stderr output in login_test.go Co-authored-by: Joshua Casey --- pkg/oidcclient/login.go | 32 ++--- pkg/oidcclient/login_test.go | 223 +++++++++++++++++++++++++---------- 2 files changed, 175 insertions(+), 80 deletions(-) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index e3798668..318e226a 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -78,6 +78,7 @@ type handlerState struct { clientID string scopes []string cache SessionCache + out io.Writer upstreamIdentityProviderName string upstreamIdentityProviderType string @@ -109,8 +110,8 @@ type handlerState struct { isTTY func(int) bool getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error) - promptForValue func(ctx context.Context, promptLabel string) (string, error) - promptForSecret func(promptLabel string) (string, error) + promptForValue func(ctx context.Context, promptLabel string, out io.Writer) (string, error) + promptForSecret func(promptLabel string, out io.Writer) (string, error) callbacks chan callbackResult } @@ -292,6 +293,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er }, promptForValue: promptForValue, promptForSecret: promptForSecret, + out: os.Stderr, } for _, opt := range opts { if err := opt(&h); err != nil { @@ -513,7 +515,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) { username := h.getEnv(defaultUsernameEnvVarName) if username == "" { - username, err = h.promptForValue(h.ctx, usernamePrompt) + username, err = h.promptForValue(h.ctx, usernamePrompt, h.out) if err != nil { return "", "", fmt.Errorf("error prompting for username: %w", err) } @@ -523,7 +525,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) { password := h.getEnv(defaultPasswordEnvVarName) if password == "" { - password, err = h.promptForSecret(passwordPrompt) + password, err = h.promptForSecret(passwordPrompt, h.out) if err != nil { return "", "", fmt.Errorf("error prompting for password: %w", err) } @@ -581,7 +583,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp // Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible). ctx, cancel := context.WithCancel(h.ctx) - cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr) + cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL) defer func() { cancel() cleanupPrompt() @@ -599,8 +601,8 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp } } -func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() { - _, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) +func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string) func() { + _, _ = fmt.Fprintf(h.out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL) // If stdin is not a TTY, print the URL but don't prompt for the manual paste, // since we have no way of reading it. @@ -621,15 +623,15 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin go func() { defer func() { // Always emit a newline so the kubectl output is visually separated from the login prompts. - _, _ = fmt.Fprintln(os.Stderr) + _, _ = fmt.Fprintln(h.out) wg.Done() }() - code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ") + code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ", h.out) if err != nil { // Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing // newline that simulates the user having pressed "enter". - _, _ = fmt.Fprint(os.Stderr, "[...]\n") + _, _ = fmt.Fprint(h.out, "[...]\n") h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)} return @@ -642,11 +644,11 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin return wg.Wait } -func promptForValue(ctx context.Context, promptLabel string) (string, error) { +func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) { if !term.IsTerminal(stdin()) { return "", errors.New("stdin is not connected to a terminal") } - _, err := fmt.Fprint(os.Stderr, promptLabel) + _, err := fmt.Fprint(out, promptLabel) if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } @@ -674,11 +676,11 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) { } } -func promptForSecret(promptLabel string) (string, error) { +func promptForSecret(promptLabel string, out io.Writer) (string, error) { if !term.IsTerminal(stdin()) { return "", errors.New("stdin is not connected to a terminal") } - _, err := fmt.Fprint(os.Stderr, promptLabel) + _, err := fmt.Fprint(out, promptLabel) if err != nil { return "", fmt.Errorf("could not print prompt to stderr: %w", err) } @@ -689,7 +691,7 @@ func promptForSecret(promptLabel string) (string, error) { // term.ReadPassword swallows the newline that was typed by the user, so to // avoid the next line of output from happening on same line as the password // prompt, we need to print a newline. - _, err = fmt.Fprint(os.Stderr, "\n") + _, err = fmt.Fprint(out, "\n") if err != nil { return "", fmt.Errorf("could not print newline to stderr: %w", err) } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index a7268d09..4a02f5fd 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -15,6 +15,8 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" + "regexp" "strings" "syscall" "testing" @@ -77,6 +79,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo time1Unix := int64(2075807775) require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) + // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: + // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 + // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g + const testCodeChallenge = "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" + testToken := oidctypes.Token{ AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))}, RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"}, @@ -316,8 +323,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo h.generateState = func() (state.State, error) { return "test-state", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil } - h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil } + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { + return "some-upstream-username", nil + } + h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "some-upstream-password", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ @@ -352,13 +361,14 @@ func TestLogin(t *testing.T) { //nolint:gocyclo } tests := []struct { - name string - opt func(t *testing.T) Option - issuer string - clientID string - wantErr string - wantToken *oidctypes.Token - wantLogs []string + name string + opt func(t *testing.T) Option + issuer string + clientID string + wantErr string + wantToken *oidctypes.Token + wantLogs []string + wantStdErr string }{ { name: "option error", @@ -699,6 +709,9 @@ func TestLogin(t *testing.T) { //nolint:gocyclo name: "listening disabled and manual prompt fails", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) require.NoError(t, WithSkipListen()(h)) h.isTTY = func(fd int) bool { return true } @@ -709,7 +722,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return fmt.Errorf("some browser open error") } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -720,12 +733,24 @@ func TestLogin(t *testing.T) { //nolint:gocyclo `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"msg"="could not open browser" "error"="some browser open error"`, }, + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+ + "&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n[...]\n\n") + + "$", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", }, { name: "listen success and manual prompt succeeds", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h)) h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") } h.isTTY = func(fd int) bool { return true } @@ -736,7 +761,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo require.Equal(t, "form_post", parsed.Query().Get("response_mode")) return nil } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { return "", fmt.Errorf("some prompt error") } return nil @@ -747,12 +772,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`, `"msg"="could not open callback listener" "error"="some listen error"`, }, + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+ + "&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n[...]\n\n") + + "$", wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error", }, { name: "timeout waiting for callback", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } + require.NoError(t, WithClient(newClientForServer(successServer))(h)) ctx, cancel := context.WithCancel(h.ctx) @@ -767,12 +805,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo }, issuer: successServer.URL, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "timed out waiting for token callback: context canceled", + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") + + "[0-9]+" + // random port + regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n") + + "$", + wantErr: "timed out waiting for token callback: context canceled", }, { name: "callback returns error", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } require.NoError(t, WithClient(newClientForServer(successServer))(h)) h.openURL = func(_ string) error { go func() { @@ -785,7 +836,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo }, issuer: successServer.URL, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, - wantErr: "error handling callback: some callback error", + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") + + "[0-9]+" + // random port + regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n") + + "$", + wantErr: "error handling callback: some callback error", }, { name: "callback returns success", @@ -823,10 +884,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo actualParams.Del("redirect_uri") require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "scope": []string{"test-scope"}, @@ -847,8 +905,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo return nil } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, + issuer: successServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") + + "[0-9]+" + // random port + regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n") + + "$", wantToken: &testToken, }, { @@ -887,10 +955,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo actualParams.Del("redirect_uri") require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "response_mode": []string{"form_post"}, @@ -912,8 +977,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo return nil } }, - issuer: formPostSuccessServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, + issuer: formPostSuccessServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") + + "[0-9]+" + // random port + regexp.QuoteMeta("%2Fcallback&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n") + + "$", wantToken: &testToken, }, { @@ -954,10 +1029,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo actualParams.Del("redirect_uri") require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "scope": []string{"test-scope"}, @@ -980,8 +1052,19 @@ func TestLogin(t *testing.T) { //nolint:gocyclo return nil } }, - issuer: successServer.URL, - wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, + issuer: successServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, + wantStdErr: "^" + + regexp.QuoteMeta("Log in by visiting this link:\n\n") + + regexp.QuoteMeta(" https://127.0.0.1:") + + "[0-9]+" + // random port + regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+ + "&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=oidc"+ + "&redirect_uri=http%3A%2F%2F127.0.0.1%3A") + + "[0-9]+" + // random port + regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") + + regexp.QuoteMeta("\n\n") + + "$", wantToken: &testToken, }, { @@ -990,7 +1073,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { require.Equal(t, "Username: ", promptLabel) return "", errors.New("some prompt error") } @@ -1007,7 +1090,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo opt: func(t *testing.T) Option { return func(h *handlerState) error { _ = defaultLDAPTestOpts(t, h, nil, nil) - h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") } + h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "", errors.New("some prompt error") } return nil } }, @@ -1069,7 +1152,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo issuer: successServer.URL, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() + - `/authorize?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": some error fetching authorize endpoint`, + `/authorize?access_type=offline&client_id=test-client-id&code_challenge=` + testCodeChallenge + + `&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&` + + `pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code` + + `&scope=test-scope&state=test-state": some error fetching authorize endpoint`, }, { name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect", @@ -1198,11 +1284,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo h.getEnv = func(_ string) string { return "" // asking for any env var returns empty as if it were unset } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { require.Equal(t, "Username: ", promptLabel) return "some-upstream-username", nil } - h.promptForSecret = func(promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) { require.Equal(t, "Password: ", promptLabel) return "some-upstream-password", nil } @@ -1242,10 +1328,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "scope": []string{"test-scope"}, @@ -1306,22 +1389,21 @@ func TestLogin(t *testing.T) { //nolint:gocyclo return "" // all other env vars are treated as if they are unset } } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } - h.promptForSecret = func(promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } cache := &mockSessionCache{t: t, getReturnsToken: nil} cacheKey := SessionCacheKey{ - Issuer: successServer.URL, - ClientID: "test-client-id", - Scopes: []string{"test-scope"}, - RedirectURI: "http://localhost:0/callback", - UpstreamProviderName: "some-upstream-name", + Issuer: successServer.URL, + ClientID: "test-client-id", + Scopes: []string{"test-scope"}, + RedirectURI: "http://localhost:0/callback", } t.Cleanup(func() { require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) @@ -1330,7 +1412,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo }) require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithCLISendingCredentials()(h)) - require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h)) discoveryRequestWasMade := false authorizeRequestWasMade := false @@ -1350,10 +1431,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "scope": []string{"test-scope"}, @@ -1362,8 +1440,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo "access_type": []string{"offline"}, "client_id": []string{"test-client-id"}, "redirect_uri": []string{"http://127.0.0.1:0/callback"}, - "pinniped_idp_name": []string{"some-upstream-name"}, - "pinniped_idp_type": []string{"ldap"}, }, req.URL.Query()) return &http.Response{ StatusCode: http.StatusFound, @@ -1418,11 +1494,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo return "" // all other env vars are treated as if they are unset } } - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } - h.promptForSecret = func(promptLabel string) (string, error) { + h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) { require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel)) return "", nil } @@ -1462,10 +1538,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username")) require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password")) require.Equal(t, url.Values{ - // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: - // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 - // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g - "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge": []string{testCodeChallenge}, "code_challenge_method": []string{"S256"}, "response_type": []string{"code"}, "scope": []string{"test-scope"}, @@ -1898,6 +1971,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements klog.SetLogger(testLogger.Logger) + buffer := bytes.Buffer{} tok, err := Login(tt.issuer, tt.clientID, WithContext(context.Background()), WithListenPort(0), @@ -1905,8 +1979,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo WithSkipBrowserOpen(), tt.opt(t), WithLogger(testLogger.Logger), + withOutWriter(t, &buffer), ) + testLogger.Expect(tt.wantLogs) + + if tt.wantStdErr == "" { + require.Empty(t, buffer.String()) + } else { + require.Regexp(t, tt.wantStdErr, buffer.String()) + } + if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) require.Nil(t, tok) @@ -1940,6 +2023,15 @@ func TestLogin(t *testing.T) { //nolint:gocyclo } } +func withOutWriter(t *testing.T, out io.Writer) Option { + return func(h *handlerState) error { + // Ensure that the proper default value has been set in the handlerState prior to overriding it for tests. + require.Equal(t, os.Stderr, h.out) + h.out = out + return nil + } +} + func TestHandlePasteCallback(t *testing.T) { const testRedirectURI = "http://127.0.0.1:12324/callback" @@ -1977,7 +2069,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel) return "", fmt.Errorf("some prompt error") } @@ -1994,7 +2086,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { return "invalid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} @@ -2018,7 +2110,7 @@ func TestHandlePasteCallback(t *testing.T) { return func(h *handlerState) error { h.isTTY = func(fd int) bool { return true } h.useFormPost = true - h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { + h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) { return "valid", nil } h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} @@ -2042,11 +2134,13 @@ func TestHandlePasteCallback(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + buf := &bytes.Buffer{} h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), pkce: pkce.Code("test-pkce"), nonce: nonce.Nonce("test-nonce"), + out: buf, } if tt.opt != nil { require.NoError(t, tt.opt(t)(h)) @@ -2054,8 +2148,7 @@ func TestHandlePasteCallback(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - var buf bytes.Buffer - h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf) + h.promptForWebLogin(ctx, "https://test-authorize-url/") require.Equal(t, "Log in by visiting this link:\n\n https://test-authorize-url/\n\n", buf.String(),