Add CORS request handling to CLI's localhost listener

This is to support the new changes in Google Chrome v98 which now
performs CORS preflight requests for the Javascript form submission
on the Supervisor's login page, even though the form is being submitted
to a localhost listener.
This commit is contained in:
Ryan Richard 2022-02-04 16:57:37 -08:00
parent 7c246784dc
commit 7b97f1533e
2 changed files with 143 additions and 26 deletions

View File

@ -834,10 +834,41 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}() }()
var params url.Values var params url.Values
if h.useFormPost { if h.useFormPost { // nolint:nestif
if r.Method == http.MethodOptions {
// Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request.
// See https://developer.chrome.com/blog/private-network-access-preflight/
origin := r.Header.Get("Origin")
if origin == "" {
// The CORS preflight request should have an origin.
h.logger.V(debugLogLevel).Info("Pinniped: Got OPTIONS request without origin header")
w.WriteHeader(http.StatusBadRequest)
return nil // keep listening for more requests
}
h.logger.V(debugLogLevel).Info("Pinniped: Got CORS preflight request from browser", "origin", origin)
issuerURL, parseErr := url.Parse(h.issuer)
if parseErr != nil {
return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr)
}
// To tell the browser that it is okay to make the real POST request, return the following response.
w.Header().Set("Access-Control-Allow-Origin", issuerURL.Scheme+"://"+issuerURL.Host)
w.Header().Set("Access-Control-Allow-Credentials", "false")
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Private-Network", "true")
// If the browser would like to send some headers on the real request, allow them. Chrome doesn't
// currently send this header at the moment. This is in case some browser in the future decides to
// request to be allowed to send specific headers by using Access-Control-Request-Headers.
requestedHeaders := r.Header.Get("Access-Control-Request-Headers")
if requestedHeaders != "" {
w.Header().Set("Access-Control-Allow-Headers", requestedHeaders)
}
w.WriteHeader(http.StatusNoContent)
return nil // keep listening for more requests
}
// Return HTTP 405 for anything that's not a POST. // Return HTTP 405 for anything that's not a POST.
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST") return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST but got %s", r.Method)
} }
// Parse and pull the response parameters from a application/x-www-form-urlencoded request body. // Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
@ -848,7 +879,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} else { } else {
// Return HTTP 405 for anything that's not a GET. // Return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET but got %s", r.Method)
} }
// Pull response parameters from the URL query string. // Pull response parameters from the URL query string.

View File

@ -1825,6 +1825,8 @@ func TestHandlePasteCallback(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
@ -1870,31 +1872,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
method string method string
query string query string
body []byte body []byte
contentType string headers http.Header
opt func(t *testing.T) Option opt func(t *testing.T) Option
wantErr string wantErr string
wantHTTPStatus int wantHTTPStatus int
wantNoCallbacks bool
wantHeaders http.Header
}{ }{
{ {
name: "wrong method", name: "wrong method",
method: "POST", method: http.MethodPost,
query: "", query: "",
wantErr: "wanted GET", wantErr: "wanted GET but got POST",
wantHTTPStatus: http.StatusMethodNotAllowed, wantHTTPStatus: http.StatusMethodNotAllowed,
}, },
{ {
name: "wrong method for form_post", name: "wrong method for form_post",
method: "GET", method: http.MethodGet,
query: "", query: "",
opt: withFormPostMode, opt: withFormPostMode,
wantErr: "wanted POST", wantErr: "wanted POST but got GET",
wantHTTPStatus: http.StatusMethodNotAllowed, wantHTTPStatus: http.StatusMethodNotAllowed,
}, },
{ {
name: "invalid form for form_post", name: "invalid form for form_post",
method: "POST", method: http.MethodPost,
query: "", query: "",
contentType: "application/x-www-form-urlencoded", headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
body: []byte(`%`), body: []byte(`%`),
opt: withFormPostMode, opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`, wantErr: `invalid form: invalid URL escape "%"`,
@ -1918,6 +1923,75 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantErr: `login failed with code "some_error": optional error description`, wantErr: `login failed with code "some_error": optional error description`,
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
}, },
{
name: "in form post mode, invalid issuer url config during CORS preflight request returns an error",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
wantErr: `invalid issuer url: parse "://bad-url": missing protocol scheme`,
wantHTTPStatus: http.StatusInternalServerError,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "://bad-url"
return nil
}
},
},
{
name: "in form post mode, options request is missing origin header results in 400 and keeps listener running",
method: http.MethodOptions,
query: "",
opt: withFormPostMode,
wantNoCallbacks: true,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "in form post mode, valid CORS request responds with 402 and CORS headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
wantNoCallbacks: true,
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Access-Control-Allow-Private-Network": {"true"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{
name: "in form post mode, valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running",
method: http.MethodOptions,
query: "",
headers: map[string][]string{
"Origin": {"https://some-origin.com"},
"Access-Control-Request-Headers": {"header1, header2, header3"},
},
wantNoCallbacks: true,
wantHTTPStatus: http.StatusNoContent,
wantHeaders: map[string][]string{
"Access-Control-Allow-Credentials": {"false"},
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
"Access-Control-Allow-Private-Network": {"true"},
"Access-Control-Allow-Headers": {"header1, header2, header3"},
},
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.issuer = "https://valid-issuer.com/with/some/path"
return nil
}
},
},
{ {
name: "invalid code", name: "invalid code",
query: "state=test-state&code=invalid", query: "state=test-state&code=invalid",
@ -1940,6 +2014,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{ {
name: "valid", name: "valid",
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
@ -1957,8 +2032,9 @@ func TestHandleAuthCodeCallback(t *testing.T) {
{ {
name: "valid form_post", name: "valid form_post",
method: http.MethodPost, method: http.MethodPost,
contentType: "application/x-www-form-urlencoded", headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
body: []byte(`state=test-state&code=valid`), body: []byte(`state=test-state&code=valid`),
wantHTTPStatus: http.StatusOK,
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.useFormPost = true h.useFormPost = true
@ -1978,11 +2054,14 @@ func TestHandleAuthCodeCallback(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
h := &handlerState{ h := &handlerState{
callbacks: make(chan callbackResult, 1), callbacks: make(chan callbackResult, 1),
state: state.State("test-state"), state: state.State("test-state"),
pkce: pkce.Code("test-pkce"), pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"), nonce: nonce.Nonce("test-nonce"),
logger: testlogger.New(t).Logger,
} }
if tt.opt != nil { if tt.opt != nil {
require.NoError(t, tt.opt(t)(h)) require.NoError(t, tt.opt(t)(h))
@ -1998,8 +2077,8 @@ func TestHandleAuthCodeCallback(t *testing.T) {
if tt.method != "" { if tt.method != "" {
req.Method = tt.method req.Method = tt.method
} }
if tt.contentType != "" { if tt.headers != nil {
req.Header.Set("Content-Type", tt.contentType) req.Header = tt.headers
} }
err = h.handleAuthCodeCallback(resp, req) err = h.handleAuthCodeCallback(resp, req)
@ -2012,11 +2091,18 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
} else { } else {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tt.wantHTTPStatus, resp.Code)
}
if tt.wantHeaders != nil {
require.Equal(t, tt.wantHeaders, resp.Header())
} }
select { select {
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
if !tt.wantNoCallbacks {
require.Fail(t, "timed out waiting to receive from callbacks channel") require.Fail(t, "timed out waiting to receive from callbacks channel")
}
case result := <-h.callbacks: case result := <-h.callbacks:
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, result.err, tt.wantErr) require.EqualError(t, result.err, tt.wantErr)