backfill unit tests for expected stderr output in login_test.go

Co-authored-by: Joshua Casey <joshuatcasey@gmail.com>
This commit is contained in:
Ryan Richard 2023-10-09 13:44:42 -07:00
parent 6ee1e35329
commit 3a21c9a35b
2 changed files with 175 additions and 80 deletions

View File

@ -78,6 +78,7 @@ type handlerState struct {
clientID string
scopes []string
cache SessionCache
out io.Writer
upstreamIdentityProviderName string
upstreamIdentityProviderType string
@ -109,8 +110,8 @@ type handlerState struct {
isTTY func(int) bool
getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error)
promptForValue func(ctx context.Context, promptLabel string, out io.Writer) (string, error)
promptForSecret func(promptLabel string, out io.Writer) (string, error)
callbacks chan callbackResult
}
@ -292,6 +293,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
},
promptForValue: promptForValue,
promptForSecret: promptForSecret,
out: os.Stderr,
}
for _, opt := range opts {
if err := opt(&h); err != nil {
@ -513,7 +515,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
username := h.getEnv(defaultUsernameEnvVarName)
if username == "" {
username, err = h.promptForValue(h.ctx, usernamePrompt)
username, err = h.promptForValue(h.ctx, usernamePrompt, h.out)
if err != nil {
return "", "", fmt.Errorf("error prompting for username: %w", err)
}
@ -523,7 +525,7 @@ func (h *handlerState) getUsernameAndPassword() (string, string, error) {
password := h.getEnv(defaultPasswordEnvVarName)
if password == "" {
password, err = h.promptForSecret(passwordPrompt)
password, err = h.promptForSecret(passwordPrompt, h.out)
if err != nil {
return "", "", fmt.Errorf("error prompting for password: %w", err)
}
@ -581,7 +583,7 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
ctx, cancel := context.WithCancel(h.ctx)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
cleanupPrompt := h.promptForWebLogin(ctx, authorizeURL)
defer func() {
cancel()
cleanupPrompt()
@ -599,8 +601,8 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
}
}
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) func() {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string) func() {
_, _ = fmt.Fprintf(h.out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it.
@ -621,15 +623,15 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
go func() {
defer func() {
// Always emit a newline so the kubectl output is visually separated from the login prompts.
_, _ = fmt.Fprintln(os.Stderr)
_, _ = fmt.Fprintln(h.out)
wg.Done()
}()
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ")
code, err := h.promptForValue(ctx, " Optionally, paste your authorization code: ", h.out)
if err != nil {
// Print a visual marker to show the the prompt is no longer waiting for user input, plus a trailing
// newline that simulates the user having pressed "enter".
_, _ = fmt.Fprint(os.Stderr, "[...]\n")
_, _ = fmt.Fprint(h.out, "[...]\n")
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
@ -642,11 +644,11 @@ func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL strin
return wg.Wait
}
func promptForValue(ctx context.Context, promptLabel string) (string, error) {
func promptForValue(ctx context.Context, promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
_, err := fmt.Fprint(os.Stderr, promptLabel)
_, err := fmt.Fprint(out, promptLabel)
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
@ -674,11 +676,11 @@ func promptForValue(ctx context.Context, promptLabel string) (string, error) {
}
}
func promptForSecret(promptLabel string) (string, error) {
func promptForSecret(promptLabel string, out io.Writer) (string, error) {
if !term.IsTerminal(stdin()) {
return "", errors.New("stdin is not connected to a terminal")
}
_, err := fmt.Fprint(os.Stderr, promptLabel)
_, err := fmt.Fprint(out, promptLabel)
if err != nil {
return "", fmt.Errorf("could not print prompt to stderr: %w", err)
}
@ -689,7 +691,7 @@ func promptForSecret(promptLabel string) (string, error) {
// term.ReadPassword swallows the newline that was typed by the user, so to
// avoid the next line of output from happening on same line as the password
// prompt, we need to print a newline.
_, err = fmt.Fprint(os.Stderr, "\n")
_, err = fmt.Fprint(out, "\n")
if err != nil {
return "", fmt.Errorf("could not print newline to stderr: %w", err)
}

View File

@ -15,6 +15,8 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"regexp"
"strings"
"syscall"
"testing"
@ -77,6 +79,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
const testCodeChallenge = "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
testToken := oidctypes.Token{
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
@ -316,8 +323,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) { return "some-upstream-username", nil }
h.promptForSecret = func(_ string) (string, error) { return "some-upstream-password", nil }
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "some-upstream-username", nil
}
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "some-upstream-password", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
@ -352,13 +361,14 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}
tests := []struct {
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *oidctypes.Token
wantLogs []string
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *oidctypes.Token
wantLogs []string
wantStdErr string
}{
{
name: "option error",
@ -699,6 +709,9 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
name: "listening disabled and manual prompt fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
require.NoError(t, WithSkipListen()(h))
h.isTTY = func(fd int) bool { return true }
@ -709,7 +722,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -720,12 +733,24 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open browser" "error"="some browser open error"`,
},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(formPostSuccessServer))(h))
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true }
@ -736,7 +761,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
@ -747,12 +772,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback"+
"&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n[...]\n\n") +
"$",
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "timeout waiting for callback",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h))
ctx, cancel := context.WithCancel(h.ctx)
@ -767,12 +805,25 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "timed out waiting for token callback: context canceled",
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "timed out waiting for token callback: context canceled",
},
{
name: "callback returns error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
require.NoError(t, WithClient(newClientForServer(successServer))(h))
h.openURL = func(_ string) error {
go func() {
@ -785,7 +836,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: "error handling callback: some callback error",
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantErr: "error handling callback: some callback error",
},
{
name: "callback returns success",
@ -823,10 +884,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
@ -847,8 +905,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken,
},
{
@ -887,10 +955,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"response_mode": []string{"form_post"},
@ -912,8 +977,18 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_mode=form_post&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken,
},
{
@ -954,10 +1029,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
@ -980,8 +1052,19 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return nil
}
},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantStdErr: "^" +
regexp.QuoteMeta("Log in by visiting this link:\n\n") +
regexp.QuoteMeta(" https://127.0.0.1:") +
"[0-9]+" + // random port
regexp.QuoteMeta("/authorize?access_type=offline&client_id=test-client-id&code_challenge="+testCodeChallenge+
"&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=oidc"+
"&redirect_uri=http%3A%2F%2F127.0.0.1%3A") +
"[0-9]+" + // random port
regexp.QuoteMeta("%2Fcallback&response_type=code&scope=test-scope&state=test-state") +
regexp.QuoteMeta("\n\n") +
"$",
wantToken: &testToken,
},
{
@ -990,7 +1073,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "", errors.New("some prompt error")
}
@ -1007,7 +1090,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
_ = defaultLDAPTestOpts(t, h, nil, nil)
h.promptForSecret = func(_ string) (string, error) { return "", errors.New("some prompt error") }
h.promptForSecret = func(_ string, _ io.Writer) (string, error) { return "", errors.New("some prompt error") }
return nil
}
},
@ -1069,7 +1152,10 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantErr: `authorization response error: Get "https://` + successServer.Listener.Addr().String() +
`/authorize?access_type=offline&client_id=test-client-id&code_challenge=VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code&scope=test-scope&state=test-state": some error fetching authorize endpoint`,
`/authorize?access_type=offline&client_id=test-client-id&code_challenge=` + testCodeChallenge +
`&code_challenge_method=S256&nonce=test-nonce&pinniped_idp_name=some-upstream-name&` +
`pinniped_idp_type=ldap&redirect_uri=http%3A%2F%2F127.0.0.1%3A0%2Fcallback&response_type=code` +
`&scope=test-scope&state=test-state": some error fetching authorize endpoint`,
},
{
name: "ldap login when the OIDC provider authorization endpoint returns something other than a redirect",
@ -1198,11 +1284,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
h.getEnv = func(_ string) string {
return "" // asking for any env var returns empty as if it were unset
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Username: ", promptLabel)
return "some-upstream-username", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.Equal(t, "Password: ", promptLabel)
return "some-upstream-password", nil
}
@ -1242,10 +1328,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
@ -1306,22 +1389,21 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset
}
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
UpstreamProviderName: "some-upstream-name",
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
@ -1330,7 +1412,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithCLISendingCredentials()(h))
require.NoError(t, WithUpstreamIdentityProvider("some-upstream-name", "ldap")(h))
discoveryRequestWasMade := false
authorizeRequestWasMade := false
@ -1350,10 +1431,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
@ -1362,8 +1440,6 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://127.0.0.1:0/callback"},
"pinniped_idp_name": []string{"some-upstream-name"},
"pinniped_idp_type": []string{"ldap"},
}, req.URL.Query())
return &http.Response{
StatusCode: http.StatusFound,
@ -1418,11 +1494,11 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return "" // all other env vars are treated as if they are unset
}
}
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
h.promptForSecret = func(promptLabel string) (string, error) {
h.promptForSecret = func(promptLabel string, _ io.Writer) (string, error) {
require.FailNow(t, fmt.Sprintf("saw unexpected prompt from the CLI: %q", promptLabel))
return "", nil
}
@ -1462,10 +1538,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
require.Equal(t, "some-upstream-username", req.Header.Get("Pinniped-Username"))
require.Equal(t, "some-upstream-password", req.Header.Get("Pinniped-Password"))
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge": []string{testCodeChallenge},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
@ -1898,6 +1971,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
testLogger := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements
klog.SetLogger(testLogger.Logger)
buffer := bytes.Buffer{}
tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()),
WithListenPort(0),
@ -1905,8 +1979,17 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
WithSkipBrowserOpen(),
tt.opt(t),
WithLogger(testLogger.Logger),
withOutWriter(t, &buffer),
)
testLogger.Expect(tt.wantLogs)
if tt.wantStdErr == "" {
require.Empty(t, buffer.String())
} else {
require.Regexp(t, tt.wantStdErr, buffer.String())
}
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
@ -1940,6 +2023,15 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}
}
func withOutWriter(t *testing.T, out io.Writer) Option {
return func(h *handlerState) error {
// Ensure that the proper default value has been set in the handlerState prior to overriding it for tests.
require.Equal(t, os.Stderr, h.out)
h.out = out
return nil
}
}
func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
@ -1977,7 +2069,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
assert.Equal(t, " Optionally, paste your authorization code: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
@ -1994,7 +2086,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2018,7 +2110,7 @@ func TestHandlePasteCallback(t *testing.T) {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForValue = func(_ context.Context, promptLabel string) (string, error) {
h.promptForValue = func(_ context.Context, promptLabel string, _ io.Writer) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -2042,11 +2134,13 @@ func TestHandlePasteCallback(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
buf := &bytes.Buffer{}
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
out: buf,
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
@ -2054,8 +2148,7 @@ func TestHandlePasteCallback(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var buf bytes.Buffer
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
h.promptForWebLogin(ctx, "https://test-authorize-url/")
require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(),