Add manual paste flow to pinniped login oidc command.

This adds a new login flow that allows manually pasting the authorization code instead of receiving a browser-based callback.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2021-07-08 14:32:44 -05:00
parent ac6ff1a03c
commit 5029495fdb
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
4 changed files with 302 additions and 56 deletions

View File

@ -10,6 +10,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"mime" "mime"
"net" "net"
"net/http" "net/http"
@ -98,6 +99,8 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
listen func(string, string) (net.Listener, error)
isTTY func(int) bool
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string) (string, error)
@ -264,6 +267,8 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
listen: net.Listen,
isTTY: term.IsTerminal,
getProvider: upstreamoidc.New, getProvider: upstreamoidc.New,
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) { validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token) return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
@ -489,16 +494,27 @@ func (h *handlerState) cliBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (
// Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint. // Open a web browser, or ask the user to open a web browser, to visit the authorize endpoint.
// Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error. // Create a localhost callback listener which exchanges the authcode for tokens. Return the tokens or an error.
func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) { func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOption) (*oidctypes.Token, error) {
// Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number). // Attempt to open a local TCP listener, logging but otherwise ignoring any error.
listener, err := net.Listen("tcp", h.listenAddr) listener, err := h.listen("tcp", h.listenAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err) h.logger.V(debugLogLevel).Error(err, "could not open callback listener")
} }
h.oauth2Config.RedirectURL = (&url.URL{
Scheme: "http", // If the listener failed to start and stdin is not a TTY, then we have no hope of succeeding,
Host: listener.Addr().String(), // since we won't be able to receive the web callback and we can't prompt for the manual auth code.
Path: h.callbackPath, if listener == nil && !h.isTTY(syscall.Stdin) {
}).String() return nil, fmt.Errorf("login failed: must have either a localhost listener or stdin must be a TTY")
}
// Update the OAuth2 redirect_uri to match the actual listener address (if there is one), or just use
// a fake ":0" port if there is no listener running.
redirectURI := url.URL{Scheme: "http", Path: h.callbackPath}
if listener == nil {
redirectURI.Host = "127.0.0.1:0"
} else {
redirectURI.Host = listener.Addr().String()
}
h.oauth2Config.RedirectURL = redirectURI.String()
// If the server supports it, request response_mode=form_post. // If the server supports it, request response_mode=form_post.
authParams := *authorizeOptions authParams := *authorizeOptions
@ -509,16 +525,24 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
// Now that we have a redirect URL with the listener port, we can build the authorize URL. // Now that we have a redirect URL with the listener port, we can build the authorize URL.
authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...) authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...)
// Start a callback server in a background goroutine. // If there is a listener running, start serving the callback handler in a background goroutine.
shutdown := h.serve(listener) if listener != nil {
defer shutdown() shutdown := h.serve(listener)
defer shutdown()
// Open the authorize URL in the users browser.
if err := h.openURL(authorizeURL); err != nil {
return nil, fmt.Errorf("could not open browser: %w", err)
} }
// Wait for either the callback or a timeout. // Open the authorize URL in the users browser, logging but otherwise ignoring any error.
if err := h.openURL(authorizeURL); err != nil {
h.logger.V(debugLogLevel).Error(err, "could not open browser")
}
ctx, cancel := context.WithCancel(h.ctx)
defer cancel()
// Prompt the user to visit the authorize URL, and to paste a manually-copied auth code (if possible).
h.promptForWebLogin(ctx, authorizeURL, os.Stderr)
// Wait for either the web callback, a pasted auth code, or a timeout.
select { select {
case <-h.ctx.Done(): case <-h.ctx.Done():
return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err()) return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err())
@ -530,6 +554,36 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
} }
} }
func (h *handlerState) promptForWebLogin(ctx context.Context, authorizeURL string, out io.Writer) {
_, _ = fmt.Fprintf(out, "Log in by visiting this link:\n\n %s\n\n", authorizeURL)
// If stdin is not a TTY, print the URL but don't prompt for the manual paste,
// since we have no way of reading it.
if !h.isTTY(syscall.Stdin) {
return
}
// If the server didn't support response_mode=form_post, don't bother prompting for the manual
// code because the user isn't going to have any easy way to manually copy it anyway.
if !h.useFormPost {
return
}
// Launch the manual auth code prompt in a background goroutine, which will be cancelled
// if the parent context is cancelled (when the login succeeds or times out).
go func() {
code, err := h.promptForSecret(ctx, " If automatic login fails, paste your authorization code to login manually: ")
if err != nil {
h.callbacks <- callbackResult{err: fmt.Errorf("failed to prompt for manual authorization code: %v", err)}
return
}
// When a code is pasted, redeem it for a token and return that result on the callbacks channel.
token, err := h.redeemAuthCode(ctx, code)
h.callbacks <- callbackResult{token: token, err: err}
}()
}
func promptForValue(ctx context.Context, promptLabel string) (string, error) { func promptForValue(ctx context.Context, promptLabel string) (string, error) {
if !term.IsTerminal(int(os.Stdin.Fd())) { if !term.IsTerminal(int(os.Stdin.Fd())) {
return "", errors.New("stdin is not connected to a terminal") return "", errors.New("stdin is not connected to a terminal")
@ -758,14 +812,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// Exchange the authorization code for access, ID, and refresh tokens and perform required // Exchange the authorization code for access, ID, and refresh tokens and perform required
// validations on the returned ID token. // validations on the returned ID token.
token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). token, err := h.redeemAuthCode(r.Context(), params.Get("code"))
ExchangeAuthcodeAndValidateTokens(
r.Context(),
params.Get("code"),
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }
@ -775,6 +822,17 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
return nil return nil
} }
func (h *handlerState) redeemAuthCode(ctx context.Context, code string) (*oidctypes.Token, error) {
return h.getProvider(h.oauth2Config, h.provider, h.httpClient).
ExchangeAuthcodeAndValidateTokens(
ctx,
code,
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
}
func (h *handlerState) serve(listener net.Listener) func() { func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))

View File

@ -10,10 +10,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"syscall"
"testing" "testing"
"time" "time"
@ -490,38 +492,94 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}) })
h.cache = cache h.cache = cache
h.listenAddr = "invalid-listen-address" h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(int) bool { return false }
return nil return nil
} }
}, },
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\"", wantLogs: []string{
"\"level\"=4 \"msg\"=\"Pinniped: Refreshing cached token.\""}, `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"level"=4 "msg"="Pinniped: Refreshing cached token."`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
// Expect this to fall through to the authorization code flow, so it fails here. // Expect this to fall through to the authorization code flow, so it fails here.
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
}, },
{ {
name: "listen failure", name: "listen failure and non-tty stdin",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.listenAddr = "invalid-listen-address" h.listen = func(net string, addr string) (net.Listener, error) {
assert.Equal(t, "tcp", net)
assert.Equal(t, "localhost:0", addr)
return nil, fmt.Errorf("some listen error")
}
h.isTTY = func(fd int) bool {
assert.Equal(t, fd, syscall.Stdin)
return false
}
return nil return nil
} }
}, },
issuer: successServer.URL, issuer: successServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + successServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "login failed: must have either a localhost listener or stdin must be a TTY",
}, },
{ {
name: "browser open failure", name: "listen failure and manual prompt fails",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return WithBrowserOpen(func(url string) error { return func(h *handlerState) error {
return fmt.Errorf("some browser open error") h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
}) h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return fmt.Errorf("some browser open error")
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
}, },
issuer: successServer.URL, issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{
wantErr: "could not open browser: some browser open error", `"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
`"msg"="could not open browser" "error"="some browser open error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
},
{
name: "listen success and manual prompt succeeds",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.listen = func(string, string) (net.Listener, error) { return nil, fmt.Errorf("some listen error") }
h.isTTY = func(fd int) bool { return true }
h.openURL = func(authorizeURL string) error {
parsed, err := url.Parse(authorizeURL)
require.NoError(t, err)
require.Equal(t, "http://127.0.0.1:0/callback", parsed.Query().Get("redirect_uri"))
require.Equal(t, "form_post", parsed.Query().Get("response_mode"))
return nil
}
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{
`"level"=4 "msg"="Pinniped: Performing OIDC discovery" "issuer"="` + formPostSuccessServer.URL + `"`,
`"msg"="could not open callback listener" "error"="some listen error"`,
},
wantErr: "error handling callback: failed to prompt for manual authorization code: some prompt error",
}, },
{ {
name: "timeout waiting for callback", name: "timeout waiting for callback",
@ -1388,10 +1446,11 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
WithContext(context.Background()), WithContext(context.Background()),
WithListenPort(0), WithListenPort(0),
WithScopes([]string{"test-scope"}), WithScopes([]string{"test-scope"}),
WithSkipBrowserOpen(),
tt.opt(t), tt.opt(t),
WithLogger(testLogger), WithLogger(testLogger),
) )
require.Equal(t, tt.wantLogs, testLogger.Lines()) testLogger.Expect(tt.wantLogs)
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr) require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok) require.Nil(t, tok)
@ -1425,6 +1484,137 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
} }
} }
func TestHandlePasteCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
tests := []struct {
name string
opt func(t *testing.T) Option
wantCallback *callbackResult
}{
{
name: "no stdin available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool {
require.Equal(t, syscall.Stdin, fd)
return false
}
h.useFormPost = true
return nil
}
},
},
{
name: "no form_post mode available",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = false
return nil
}
},
},
{
name: "prompt fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
assert.Equal(t, " If automatic login fails, paste your authorization code to login manually: ", promptLabel)
return "", fmt.Errorf("some prompt error")
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("failed to prompt for manual authorization code: some prompt error"),
},
},
{
name: "redeeming code fails",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "invalid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(nil, fmt.Errorf("some exchange error"))
return mock
}
return nil
}
},
wantCallback: &callbackResult{
err: fmt.Errorf("some exchange error"),
},
},
{
name: "success",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.isTTY = func(fd int) bool { return true }
h.useFormPost = true
h.promptForSecret = func(ctx context.Context, promptLabel string) (string, error) {
return "valid", nil
}
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
wantCallback: &callbackResult{
token: &oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var buf bytes.Buffer
h.promptForWebLogin(ctx, "https://test-authorize-url/", &buf)
require.Equal(t,
"Log in by visiting this link:\n\n https://test-authorize-url/\n\n",
buf.String(),
)
if tt.wantCallback != nil {
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
case result := <-h.callbacks:
require.Equal(t, *tt.wantCallback, result)
}
}
})
}
}
func TestHandleAuthCodeCallback(t *testing.T) { func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback" const testRedirectURI = "http://127.0.0.1:12324/callback"

View File

@ -307,16 +307,15 @@ func runPinnipedLoginOIDC(
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr)) reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderr))
scanner := bufio.NewScanner(reader) scanner := bufio.NewScanner(reader)
const prompt = "Please log in: "
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
if strings.HasPrefix(line, prompt) { if err == nil && loginURL.Scheme == "https" {
loginURLChan <- strings.TrimPrefix(line, prompt) loginURLChan <- loginURL.String()
return nil return nil
} }
} }
return fmt.Errorf("expected stderr to contain %s", prompt) return fmt.Errorf("expected stderr to contain login URL")
}) })
// Start a background goroutine to read stdout from the CLI and parse out an ExecCredential. // Start a background goroutine to read stdout from the CLI and parse out an ExecCredential.

View File

@ -195,16 +195,15 @@ func TestE2EFullIntegration(t *testing.T) {
}() }()
reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe)) reader := bufio.NewReader(testlib.NewLoggerReader(t, "stderr", stderrPipe))
line, err := reader.ReadString('\n') scanner := bufio.NewScanner(reader)
if err != nil { for scanner.Scan() {
return fmt.Errorf("could not read login URL line from stderr: %w", err) loginURL, err := url.Parse(strings.TrimSpace(scanner.Text()))
if err == nil && loginURL.Scheme == "https" {
loginURLChan <- loginURL.String()
return nil
}
} }
const prompt = "Please log in: " return fmt.Errorf("expected stderr to contain login URL")
if !strings.HasPrefix(line, prompt) {
return fmt.Errorf("expected %q to have prefix %q", line, prompt)
}
loginURLChan <- strings.TrimPrefix(line, prompt)
return readAndExpectEmpty(reader)
}) })
// Start a background goroutine to read stdout from kubectl and return the result as a string. // Start a background goroutine to read stdout from kubectl and return the result as a string.