Add CORS request handling to CLI's localhost listener
This is to support the new changes in Google Chrome v98 which now performs CORS preflight requests for the Javascript form submission on the Supervisor's login page, even though the form is being submitted to a localhost listener.
This commit is contained in:
parent
7c246784dc
commit
7b97f1533e
@ -834,10 +834,41 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
||||
}()
|
||||
|
||||
var params url.Values
|
||||
if h.useFormPost {
|
||||
if h.useFormPost { // nolint:nestif
|
||||
if r.Method == http.MethodOptions {
|
||||
// Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request.
|
||||
// See https://developer.chrome.com/blog/private-network-access-preflight/
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
// The CORS preflight request should have an origin.
|
||||
h.logger.V(debugLogLevel).Info("Pinniped: Got OPTIONS request without origin header")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return nil // keep listening for more requests
|
||||
}
|
||||
h.logger.V(debugLogLevel).Info("Pinniped: Got CORS preflight request from browser", "origin", origin)
|
||||
issuerURL, parseErr := url.Parse(h.issuer)
|
||||
if parseErr != nil {
|
||||
return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr)
|
||||
}
|
||||
// To tell the browser that it is okay to make the real POST request, return the following response.
|
||||
w.Header().Set("Access-Control-Allow-Origin", issuerURL.Scheme+"://"+issuerURL.Host)
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "false")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Private-Network", "true")
|
||||
// If the browser would like to send some headers on the real request, allow them. Chrome doesn't
|
||||
// currently send this header at the moment. This is in case some browser in the future decides to
|
||||
// request to be allowed to send specific headers by using Access-Control-Request-Headers.
|
||||
requestedHeaders := r.Header.Get("Access-Control-Request-Headers")
|
||||
if requestedHeaders != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", requestedHeaders)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return nil // keep listening for more requests
|
||||
}
|
||||
|
||||
// Return HTTP 405 for anything that's not a POST.
|
||||
if r.Method != http.MethodPost {
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST")
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST but got %s", r.Method)
|
||||
}
|
||||
|
||||
// Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
|
||||
@ -848,7 +879,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
||||
} else {
|
||||
// Return HTTP 405 for anything that's not a GET.
|
||||
if r.Method != http.MethodGet {
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET but got %s", r.Method)
|
||||
}
|
||||
|
||||
// Pull response parameters from the URL query string.
|
||||
|
@ -1825,6 +1825,8 @@ func TestHandlePasteCallback(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &handlerState{
|
||||
callbacks: make(chan callbackResult, 1),
|
||||
state: state.State("test-state"),
|
||||
@ -1866,35 +1868,38 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
query string
|
||||
body []byte
|
||||
contentType string
|
||||
opt func(t *testing.T) Option
|
||||
wantErr string
|
||||
wantHTTPStatus int
|
||||
name string
|
||||
method string
|
||||
query string
|
||||
body []byte
|
||||
headers http.Header
|
||||
opt func(t *testing.T) Option
|
||||
|
||||
wantErr string
|
||||
wantHTTPStatus int
|
||||
wantNoCallbacks bool
|
||||
wantHeaders http.Header
|
||||
}{
|
||||
{
|
||||
name: "wrong method",
|
||||
method: "POST",
|
||||
method: http.MethodPost,
|
||||
query: "",
|
||||
wantErr: "wanted GET",
|
||||
wantErr: "wanted GET but got POST",
|
||||
wantHTTPStatus: http.StatusMethodNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "wrong method for form_post",
|
||||
method: "GET",
|
||||
method: http.MethodGet,
|
||||
query: "",
|
||||
opt: withFormPostMode,
|
||||
wantErr: "wanted POST",
|
||||
wantErr: "wanted POST but got GET",
|
||||
wantHTTPStatus: http.StatusMethodNotAllowed,
|
||||
},
|
||||
{
|
||||
name: "invalid form for form_post",
|
||||
method: "POST",
|
||||
method: http.MethodPost,
|
||||
query: "",
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
|
||||
body: []byte(`%`),
|
||||
opt: withFormPostMode,
|
||||
wantErr: `invalid form: invalid URL escape "%"`,
|
||||
@ -1918,6 +1923,75 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
wantErr: `login failed with code "some_error": optional error description`,
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "in form post mode, invalid issuer url config during CORS preflight request returns an error",
|
||||
method: http.MethodOptions,
|
||||
query: "",
|
||||
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
|
||||
wantErr: `invalid issuer url: parse "://bad-url": missing protocol scheme`,
|
||||
wantHTTPStatus: http.StatusInternalServerError,
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.useFormPost = true
|
||||
h.issuer = "://bad-url"
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in form post mode, options request is missing origin header results in 400 and keeps listener running",
|
||||
method: http.MethodOptions,
|
||||
query: "",
|
||||
opt: withFormPostMode,
|
||||
wantNoCallbacks: true,
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "in form post mode, valid CORS request responds with 402 and CORS headers and keeps listener running",
|
||||
method: http.MethodOptions,
|
||||
query: "",
|
||||
headers: map[string][]string{"Origin": {"https://some-origin.com"}},
|
||||
wantNoCallbacks: true,
|
||||
wantHTTPStatus: http.StatusNoContent,
|
||||
wantHeaders: map[string][]string{
|
||||
"Access-Control-Allow-Credentials": {"false"},
|
||||
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
|
||||
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
|
||||
"Access-Control-Allow-Private-Network": {"true"},
|
||||
},
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.useFormPost = true
|
||||
h.issuer = "https://valid-issuer.com/with/some/path"
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in form post mode, valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running",
|
||||
method: http.MethodOptions,
|
||||
query: "",
|
||||
headers: map[string][]string{
|
||||
"Origin": {"https://some-origin.com"},
|
||||
"Access-Control-Request-Headers": {"header1, header2, header3"},
|
||||
},
|
||||
wantNoCallbacks: true,
|
||||
wantHTTPStatus: http.StatusNoContent,
|
||||
wantHeaders: map[string][]string{
|
||||
"Access-Control-Allow-Credentials": {"false"},
|
||||
"Access-Control-Allow-Methods": {"POST, OPTIONS"},
|
||||
"Access-Control-Allow-Origin": {"https://valid-issuer.com"},
|
||||
"Access-Control-Allow-Private-Network": {"true"},
|
||||
"Access-Control-Allow-Headers": {"header1, header2, header3"},
|
||||
},
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.useFormPost = true
|
||||
h.issuer = "https://valid-issuer.com/with/some/path"
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid code",
|
||||
query: "state=test-state&code=invalid",
|
||||
@ -1938,8 +2012,9 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
query: "state=test-state&code=valid",
|
||||
name: "valid",
|
||||
query: "state=test-state&code=valid",
|
||||
wantHTTPStatus: http.StatusOK,
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
|
||||
@ -1955,10 +2030,11 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid form_post",
|
||||
method: http.MethodPost,
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
body: []byte(`state=test-state&code=valid`),
|
||||
name: "valid form_post",
|
||||
method: http.MethodPost,
|
||||
headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}},
|
||||
body: []byte(`state=test-state&code=valid`),
|
||||
wantHTTPStatus: http.StatusOK,
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.useFormPost = true
|
||||
@ -1978,11 +2054,14 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := &handlerState{
|
||||
callbacks: make(chan callbackResult, 1),
|
||||
state: state.State("test-state"),
|
||||
pkce: pkce.Code("test-pkce"),
|
||||
nonce: nonce.Nonce("test-nonce"),
|
||||
logger: testlogger.New(t).Logger,
|
||||
}
|
||||
if tt.opt != nil {
|
||||
require.NoError(t, tt.opt(t)(h))
|
||||
@ -1998,8 +2077,8 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
if tt.method != "" {
|
||||
req.Method = tt.method
|
||||
}
|
||||
if tt.contentType != "" {
|
||||
req.Header.Set("Content-Type", tt.contentType)
|
||||
if tt.headers != nil {
|
||||
req.Header = tt.headers
|
||||
}
|
||||
|
||||
err = h.handleAuthCodeCallback(resp, req)
|
||||
@ -2012,11 +2091,18 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantHTTPStatus, resp.Code)
|
||||
}
|
||||
|
||||
if tt.wantHeaders != nil {
|
||||
require.Equal(t, tt.wantHeaders, resp.Header())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(1 * time.Second):
|
||||
require.Fail(t, "timed out waiting to receive from callbacks channel")
|
||||
if !tt.wantNoCallbacks {
|
||||
require.Fail(t, "timed out waiting to receive from callbacks channel")
|
||||
}
|
||||
case result := <-h.callbacks:
|
||||
if tt.wantErr != "" {
|
||||
require.EqualError(t, result.err, tt.wantErr)
|
||||
|
Loading…
Reference in New Issue
Block a user