diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index d9364689..223e7fb2 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -834,10 +834,41 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req }() var params url.Values - if h.useFormPost { + if h.useFormPost { // nolint:nestif + if r.Method == http.MethodOptions { + // Google Chrome decided that it should do CORS preflight checks for this Javascript form submission POST request. + // See https://developer.chrome.com/blog/private-network-access-preflight/ + origin := r.Header.Get("Origin") + if origin == "" { + // The CORS preflight request should have an origin. + h.logger.V(debugLogLevel).Info("Pinniped: Got OPTIONS request without origin header") + w.WriteHeader(http.StatusBadRequest) + return nil // keep listening for more requests + } + h.logger.V(debugLogLevel).Info("Pinniped: Got CORS preflight request from browser", "origin", origin) + issuerURL, parseErr := url.Parse(h.issuer) + if parseErr != nil { + return httperr.Wrap(http.StatusInternalServerError, "invalid issuer url", parseErr) + } + // To tell the browser that it is okay to make the real POST request, return the following response. + w.Header().Set("Access-Control-Allow-Origin", issuerURL.Scheme+"://"+issuerURL.Host) + w.Header().Set("Access-Control-Allow-Credentials", "false") + w.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Private-Network", "true") + // If the browser would like to send some headers on the real request, allow them. Chrome doesn't + // currently send this header at the moment. This is in case some browser in the future decides to + // request to be allowed to send specific headers by using Access-Control-Request-Headers. + requestedHeaders := r.Header.Get("Access-Control-Request-Headers") + if requestedHeaders != "" { + w.Header().Set("Access-Control-Allow-Headers", requestedHeaders) + } + w.WriteHeader(http.StatusNoContent) + return nil // keep listening for more requests + } + // Return HTTP 405 for anything that's not a POST. if r.Method != http.MethodPost { - return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST") + return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST but got %s", r.Method) } // Parse and pull the response parameters from a application/x-www-form-urlencoded request body. @@ -848,7 +879,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } else { // Return HTTP 405 for anything that's not a GET. if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") + return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET but got %s", r.Method) } // Pull response parameters from the URL query string. diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 8ee920d7..bae18b49 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -1825,6 +1825,8 @@ func TestHandlePasteCallback(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), @@ -1866,35 +1868,38 @@ func TestHandleAuthCodeCallback(t *testing.T) { } } tests := []struct { - name string - method string - query string - body []byte - contentType string - opt func(t *testing.T) Option - wantErr string - wantHTTPStatus int + name string + method string + query string + body []byte + headers http.Header + opt func(t *testing.T) Option + + wantErr string + wantHTTPStatus int + wantNoCallbacks bool + wantHeaders http.Header }{ { name: "wrong method", - method: "POST", + method: http.MethodPost, query: "", - wantErr: "wanted GET", + wantErr: "wanted GET but got POST", wantHTTPStatus: http.StatusMethodNotAllowed, }, { name: "wrong method for form_post", - method: "GET", + method: http.MethodGet, query: "", opt: withFormPostMode, - wantErr: "wanted POST", + wantErr: "wanted POST but got GET", wantHTTPStatus: http.StatusMethodNotAllowed, }, { name: "invalid form for form_post", - method: "POST", + method: http.MethodPost, query: "", - contentType: "application/x-www-form-urlencoded", + headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, body: []byte(`%`), opt: withFormPostMode, wantErr: `invalid form: invalid URL escape "%"`, @@ -1918,6 +1923,75 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantErr: `login failed with code "some_error": optional error description`, wantHTTPStatus: http.StatusBadRequest, }, + { + name: "in form post mode, invalid issuer url config during CORS preflight request returns an error", + method: http.MethodOptions, + query: "", + headers: map[string][]string{"Origin": {"https://some-origin.com"}}, + wantErr: `invalid issuer url: parse "://bad-url": missing protocol scheme`, + wantHTTPStatus: http.StatusInternalServerError, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + h.issuer = "://bad-url" + return nil + } + }, + }, + { + name: "in form post mode, options request is missing origin header results in 400 and keeps listener running", + method: http.MethodOptions, + query: "", + opt: withFormPostMode, + wantNoCallbacks: true, + wantHTTPStatus: http.StatusBadRequest, + }, + { + name: "in form post mode, valid CORS request responds with 402 and CORS headers and keeps listener running", + method: http.MethodOptions, + query: "", + headers: map[string][]string{"Origin": {"https://some-origin.com"}}, + wantNoCallbacks: true, + wantHTTPStatus: http.StatusNoContent, + wantHeaders: map[string][]string{ + "Access-Control-Allow-Credentials": {"false"}, + "Access-Control-Allow-Methods": {"POST, OPTIONS"}, + "Access-Control-Allow-Origin": {"https://valid-issuer.com"}, + "Access-Control-Allow-Private-Network": {"true"}, + }, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + h.issuer = "https://valid-issuer.com/with/some/path" + return nil + } + }, + }, + { + name: "in form post mode, valid CORS request with Access-Control-Request-Headers responds with 402 and CORS headers including Access-Control-Allow-Headers and keeps listener running", + method: http.MethodOptions, + query: "", + headers: map[string][]string{ + "Origin": {"https://some-origin.com"}, + "Access-Control-Request-Headers": {"header1, header2, header3"}, + }, + wantNoCallbacks: true, + wantHTTPStatus: http.StatusNoContent, + wantHeaders: map[string][]string{ + "Access-Control-Allow-Credentials": {"false"}, + "Access-Control-Allow-Methods": {"POST, OPTIONS"}, + "Access-Control-Allow-Origin": {"https://valid-issuer.com"}, + "Access-Control-Allow-Private-Network": {"true"}, + "Access-Control-Allow-Headers": {"header1, header2, header3"}, + }, + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + h.issuer = "https://valid-issuer.com/with/some/path" + return nil + } + }, + }, { name: "invalid code", query: "state=test-state&code=invalid", @@ -1938,8 +2012,9 @@ func TestHandleAuthCodeCallback(t *testing.T) { }, }, { - name: "valid", - query: "state=test-state&code=valid", + name: "valid", + query: "state=test-state&code=valid", + wantHTTPStatus: http.StatusOK, opt: func(t *testing.T) Option { return func(h *handlerState) error { h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} @@ -1955,10 +2030,11 @@ func TestHandleAuthCodeCallback(t *testing.T) { }, }, { - name: "valid form_post", - method: http.MethodPost, - contentType: "application/x-www-form-urlencoded", - body: []byte(`state=test-state&code=valid`), + name: "valid form_post", + method: http.MethodPost, + headers: map[string][]string{"Content-Type": {"application/x-www-form-urlencoded"}}, + body: []byte(`state=test-state&code=valid`), + wantHTTPStatus: http.StatusOK, opt: func(t *testing.T) Option { return func(h *handlerState) error { h.useFormPost = true @@ -1978,11 +2054,14 @@ func TestHandleAuthCodeCallback(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := &handlerState{ callbacks: make(chan callbackResult, 1), state: state.State("test-state"), pkce: pkce.Code("test-pkce"), nonce: nonce.Nonce("test-nonce"), + logger: testlogger.New(t).Logger, } if tt.opt != nil { require.NoError(t, tt.opt(t)(h)) @@ -1998,8 +2077,8 @@ func TestHandleAuthCodeCallback(t *testing.T) { if tt.method != "" { req.Method = tt.method } - if tt.contentType != "" { - req.Header.Set("Content-Type", tt.contentType) + if tt.headers != nil { + req.Header = tt.headers } err = h.handleAuthCodeCallback(resp, req) @@ -2012,11 +2091,18 @@ func TestHandleAuthCodeCallback(t *testing.T) { } } else { require.NoError(t, err) + require.Equal(t, tt.wantHTTPStatus, resp.Code) + } + + if tt.wantHeaders != nil { + require.Equal(t, tt.wantHeaders, resp.Header()) } select { case <-time.After(1 * time.Second): - require.Fail(t, "timed out waiting to receive from callbacks channel") + if !tt.wantNoCallbacks { + require.Fail(t, "timed out waiting to receive from callbacks channel") + } case result := <-h.callbacks: if tt.wantErr != "" { require.EqualError(t, result.err, tt.wantErr)