diff --git a/internal/httputil/securityheader/securityheader.go b/internal/httputil/securityheader/securityheader.go index 42cf1e2f..2bb3af12 100644 --- a/internal/httputil/securityheader/securityheader.go +++ b/internal/httputil/securityheader/securityheader.go @@ -9,7 +9,6 @@ import "net/http" // Wrap the provided http.Handler so it sets appropriate security-related response headers. func Wrap(wrapped http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wrapped.ServeHTTP(w, r) h := w.Header() h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") h.Set("X-Frame-Options", "DENY") @@ -17,14 +16,9 @@ func Wrap(wrapped http.Handler) http.Handler { h.Set("X-Content-Type-Options", "nosniff") h.Set("Referrer-Policy", "no-referrer") h.Set("X-DNS-Prefetch-Control", "off") - - // first overwrite existing Cache-Control header with Set, then append more headers with Add - h.Set("Cache-Control", "no-cache") - h.Add("Cache-Control", "no-store") - h.Add("Cache-Control", "max-age=0") - h.Add("Cache-Control", "must-revalidate") - + h.Set("Cache-Control", "no-cache,no-store,max-age=0,must-revalidate") h.Set("Pragma", "no-cache") h.Set("Expires", "0") + wrapped.ServeHTTP(w, r) }) } diff --git a/internal/httputil/securityheader/securityheader_test.go b/internal/httputil/securityheader/securityheader_test.go index 715e7cd5..a0688c1a 100644 --- a/internal/httputil/securityheader/securityheader_test.go +++ b/internal/httputil/securityheader/securityheader_test.go @@ -4,22 +4,40 @@ package securityheader import ( + "context" + "io/ioutil" "net/http" "net/http/httptest" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWrap(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test-Header", "test value") _, _ = w.Write([]byte("hello world")) - }) - rec := httptest.NewRecorder() - Wrap(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "hello world", rec.Body.String()) - require.EqualValues(t, http.Header{ + }))) + t.Cleanup(testServer.Close) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "hello world", string(respBody)) + + expected := http.Header{ + "X-Test-Header": []string{"test value"}, "Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"}, "Content-Type": []string{"text/plain; charset=utf-8"}, "Referrer-Policy": []string{"no-referrer"}, @@ -27,8 +45,11 @@ func TestWrap(t *testing.T) { "X-Frame-Options": []string{"DENY"}, "X-Xss-Protection": []string{"1; mode=block"}, "X-Dns-Prefetch-Control": []string{"off"}, - "Cache-Control": []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, + "Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"}, "Pragma": []string{"no-cache"}, "Expires": []string{"0"}, - }, rec.Header()) + } + for key, values := range expected { + assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key) + } } diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go index b0c3018d..54fc8563 100644 --- a/internal/testutil/assertions.go +++ b/internal/testutil/assertions.go @@ -61,7 +61,9 @@ func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options")) require.Equal(t, "no-referrer", response.Header().Get("Referrer-Policy")) require.Equal(t, "off", response.Header().Get("X-DNS-Prefetch-Control")) - require.ElementsMatch(t, []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, response.Header().Values("Cache-Control")) require.Equal(t, "no-cache", response.Header().Get("Pragma")) require.Equal(t, "0", response.Header().Get("Expires")) + + // This check is more relaxed since Fosite can override the base header we set. + require.Contains(t, response.Header().Get("Cache-Control"), "no-store") } diff --git a/test/integration/supervisor_login_test.go b/test/integration/supervisor_login_test.go index 6fa7875f..6fc03a7d 100644 --- a/test/integration/supervisor_login_test.go +++ b/test/integration/supervisor_login_test.go @@ -57,19 +57,25 @@ func TestSupervisorLogin(t *testing.T) { require.NoError(t, err) // Create an HTTP client that can reach the downstream discovery endpoint using the CA certs. - httpClient := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, - Proxy: func(req *http.Request) (*url.URL, error) { - if env.Proxy == "" { - t.Logf("passing request for %s with no proxy", req.URL) - return nil, nil - } - proxyURL, err := url.Parse(env.Proxy) - require.NoError(t, err) - t.Logf("passing request for %s through proxy %s", req.URL, proxyURL.String()) - return proxyURL, nil + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, + Proxy: func(req *http.Request) (*url.URL, error) { + if env.Proxy == "" { + t.Logf("passing request for %s with no proxy", req.URL) + return nil, nil + } + proxyURL, err := url.Parse(env.Proxy) + require.NoError(t, err) + t.Logf("passing request for %s through proxy %s", req.URL, proxyURL.String()) + return proxyURL, nil + }, }, - }} + // Don't follow redirects automatically. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } oidcHTTPClientContext := coreosoidc.ClientContext(ctx, httpClient) // Use the CA to issue a TLS server cert. @@ -144,6 +150,14 @@ func TestSupervisorLogin(t *testing.T) { pkceParam.Method(), ) + // Make the authorize request one "manually" so we can check its response headers. + authorizeRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthorizeURL, nil) + require.NoError(t, err) + authorizeResp, err := httpClient.Do(authorizeRequest) + require.NoError(t, err) + require.NoError(t, authorizeResp.Body.Close()) + expectSecurityHeaders(t, authorizeResp) + // Open the web browser and navigate to the downstream authorize URL. page := browsertest.Open(t) t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL)) @@ -306,3 +320,16 @@ func doTokenExchange(t *testing.T, config *oauth2.Config, tokenResponse *oauth2. require.NoError(t, err) t.Logf("exchanged token claims:\n%s", string(indentedClaims)) } + +func expectSecurityHeaders(t *testing.T, response *http.Response) { + h := response.Header + assert.Equal(t, "default-src 'none'; frame-ancestors 'none'", h.Get("Content-Security-Policy")) + assert.Equal(t, "DENY", h.Get("X-Frame-Options")) + assert.Equal(t, "1; mode=block", h.Get("X-XSS-Protection")) + assert.Equal(t, "nosniff", h.Get("X-Content-Type-Options")) + assert.Equal(t, "no-referrer", h.Get("Referrer-Policy")) + assert.Equal(t, "off", h.Get("X-DNS-Prefetch-Control")) + assert.Equal(t, "no-cache,no-store,max-age=0,must-revalidate", h.Get("Cache-Control")) + assert.Equal(t, "no-cache", h.Get("Pragma")) + assert.Equal(t, "0", h.Get("Expires")) +}