Add CORS request handling to CLI's localhost listener

This is to support the new changes in Google Chrome v98 which now
performs CORS preflight requests for the Javascript form submission
on the Supervisor's login page, even though the form is being submitted
to a localhost listener.
This commit is contained in:
Ryan Richard 2022-02-04 16:57:37 -08:00
parent 7c246784dc
commit 7b97f1533e
2 changed files with 143 additions and 26 deletions

View File

@ -834,10 +834,41 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}()
var params url.Values
if h.useFormPost {
if h.useFormPost { // nolint:nestif
if r.Method == http.MethodOptions {
// Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request.
// See https://developer.chrome.com/blog/private-network-access-preflight/
origin := r.Header.Get("Origin")
if origin == "" {
// The CORS preflight request should have an origin.
h.logger.V(debugLogLevel).Info("Pinniped: Got OPTIONS request without origin header")
w.WriteHeader(http.StatusBadRequest)
return nil // keep listening for more requests
}
h.logger.V(debugLogLevel).Info("Pinniped: Got CORS preflight request from browser", "origin", origin)
issuerURL, parseErr := url.Parse(h.issuer)
if parseErr != nil {
return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr)
}
// To tell the browser that it is okay to make the real POST request, return the following response.
w.Header().Set("Access-Control-Allow-Origin", issuerURL.Scheme+"://"+issuerURL.Host)
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Private-Network", "true")
// If the browser would like to send some headers on the real request, allow them. Chrome doesn't
// currently send this header at the moment. This is in case some browser in the future decides to
// request to be allowed to send specific headers by using Access-Control-Request-Headers.
requestedHeaders := r.Header.Get("Access-Control-Request-Headers")
if requestedHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestedHeaders)
}
w.WriteHeader(http.StatusNoContent)
return nil // keep listening for more requests
}
// Return HTTP 405 for anything that's not a POST.
if r.Method != http.MethodPost {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST")
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST but got %s", r.Method)
}
// Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
@ -848,7 +879,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} else {
// Return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET but got %s", r.Method)
}
// Pull response parameters from the URL query string.

View File

@ -1825,6 +1825,8 @@ func TestHandlePasteCallback(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
@ -1866,35 +1868,38 @@ func TestHandleAuthCodeCallback(t *testing.T) {
}
}
tests := []struct {
name string
method string
query string
body []byte
contentType string
opt func(t *testing.T) Option
wantErr string
wantHTTPStatus int
name string
method string
query string
body []byte
headers http.Header
opt func(t *testing.T) Option
wantErr string
wantHTTPStatus int
wantNoCallbacks bool
wantHeaders http.Header
}{
{
name: "wrong method",
method: "POST",
method: http.MethodPost,
query: "",
wantErr: "wanted GET",
wantErr: "wanted GET but got POST",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "wrong method for form_post",
method: "GET",
method: http.MethodGet,
query: "",
opt: withFormPostMode,
wantErr: "wanted POST",
wantErr: "wanted POST but got GET",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "invalid form for form_post",
method: "POST",
method: http.MethodPost,
query: "",
contentType: "application/x-www-form-urlencoded",
headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
body: []byte(`%`),
opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`,
@ -1918,6 +1923,75 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantErr: `login failed with code "some_error": optional error description`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "in form post mode, invalid issuer url config during CORS preflight request returns an error",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
wantErr: `invalid issuer url: parse "://bad-url": missing protocol scheme`,
wantHTTPStatus: http.StatusInternalServerError,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "://bad-url"
return nil
}
},
},
{
name: "in form post mode, options request is missing origin header results in 400 and keeps listener running",
method: http.MethodOptions,
query: "",
opt: withFormPostMode,
wantNoCallbacks: true,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "in form post mode, valid CORS request responds with 402 and CORS headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
wantNoCallbacks: true,
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Access-Control-Allow-Private-Network": {"true"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{
name: "in form post mode, valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{
"Origin": {"https://some-origin.com"},
"Access-Control-Request-Headers": {"header1, header2, header3"},
},
wantNoCallbacks: true,
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Access-Control-Allow-Private-Network": {"true"},
"Access-Control-Allow-Headers": {"header1, header2, header3"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{
name: "invalid code",
query: "state=test-state&code=invalid",
@ -1938,8 +2012,9 @@ func TestHandleAuthCodeCallback(t *testing.T) {
},
},
{
name: "valid",
query: "state=test-state&code=valid",
name: "valid",
query: "state=test-state&code=valid",
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -1955,10 +2030,11 @@ func TestHandleAuthCodeCallback(t *testing.T) {
},
},
{
name: "valid form_post",
method: http.MethodPost,
contentType: "application/x-www-form-urlencoded",
body: []byte(`state=test-state&code=valid`),
name: "valid form_post",
method: http.MethodPost,
headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
body: []byte(`state=test-state&code=valid`),
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
@ -1978,11 +2054,14 @@ func TestHandleAuthCodeCallback(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
logger: testlogger.New(t).Logger,
}
if tt.opt != nil {
require.NoError(t, tt.opt(t)(h))
@ -1998,8 +2077,8 @@ func TestHandleAuthCodeCallback(t *testing.T) {
if tt.method != "" {
req.Method = tt.method
}
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
if tt.headers != nil {
req.Header = tt.headers
}
err = h.handleAuthCodeCallback(resp, req)
@ -2012,11 +2091,18 @@ func TestHandleAuthCodeCallback(t *testing.T) {
}
} else {
require.NoError(t, err)
require.Equal(t, tt.wantHTTPStatus, resp.Code)
}
if tt.wantHeaders != nil {
require.Equal(t, tt.wantHeaders, resp.Header())
}
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
if !tt.wantNoCallbacks {
require.Fail(t, "timed out waiting to receive from callbacks channel")
}
case result := <-h.callbacks:
if tt.wantErr != "" {
require.EqualError(t, result.err, tt.wantErr)