Merge pull request #288 from mattmoyer/fixup-securityheaders

Fix a regression in securityheaders package and add tests.
This commit is contained in:
Matt Moyer 2020-12-16 13:46:28 -06:00 committed by GitHub
commit 2840e4e152
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 30 deletions

View File

@ -9,7 +9,6 @@ import "net/http"
// Wrap the provided http.Handler so it sets appropriate security-related response headers. // Wrap the provided http.Handler so it sets appropriate security-related response headers.
func Wrap(wrapped http.Handler) http.Handler { func Wrap(wrapped http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrapped.ServeHTTP(w, r)
h := w.Header() h := w.Header()
h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'")
h.Set("X-Frame-Options", "DENY") h.Set("X-Frame-Options", "DENY")
@ -17,14 +16,9 @@ func Wrap(wrapped http.Handler) http.Handler {
h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Content-Type-Options", "nosniff")
h.Set("Referrer-Policy", "no-referrer") h.Set("Referrer-Policy", "no-referrer")
h.Set("X-DNS-Prefetch-Control", "off") h.Set("X-DNS-Prefetch-Control", "off")
h.Set("Cache-Control", "no-cache,no-store,max-age=0,must-revalidate")
// first overwrite existing Cache-Control header with Set, then append more headers with Add
h.Set("Cache-Control", "no-cache")
h.Add("Cache-Control", "no-store")
h.Add("Cache-Control", "max-age=0")
h.Add("Cache-Control", "must-revalidate")
h.Set("Pragma", "no-cache") h.Set("Pragma", "no-cache")
h.Set("Expires", "0") h.Set("Expires", "0")
wrapped.ServeHTTP(w, r)
}) })
} }

View File

@ -4,22 +4,40 @@
package securityheader package securityheader
import ( import (
"context"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWrap(t *testing.T) { func TestWrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test-Header", "test value")
_, _ = w.Write([]byte("hello world")) _, _ = w.Write([]byte("hello world"))
}) })))
rec := httptest.NewRecorder() t.Cleanup(testServer.Close)
Wrap(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
require.Equal(t, http.StatusOK, rec.Code) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
require.Equal(t, "hello world", rec.Body.String()) defer cancel()
require.EqualValues(t, http.Header{
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "hello world", string(respBody))
expected := http.Header{
"X-Test-Header": []string{"test value"},
"Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"}, "Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"},
"Content-Type": []string{"text/plain; charset=utf-8"}, "Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"}, "Referrer-Policy": []string{"no-referrer"},
@ -27,8 +45,11 @@ func TestWrap(t *testing.T) {
"X-Frame-Options": []string{"DENY"}, "X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"}, "X-Xss-Protection": []string{"1; mode=block"},
"X-Dns-Prefetch-Control": []string{"off"}, "X-Dns-Prefetch-Control": []string{"off"},
"Cache-Control": []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, "Cache-Control": []string{"no-cache,no-store,max-age=0,must-revalidate"},
"Pragma": []string{"no-cache"}, "Pragma": []string{"no-cache"},
"Expires": []string{"0"}, "Expires": []string{"0"},
}, rec.Header()) }
for key, values := range expected {
assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key)
}
} }

View File

@ -61,7 +61,9 @@ func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options")) require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options"))
require.Equal(t, "no-referrer", response.Header().Get("Referrer-Policy")) require.Equal(t, "no-referrer", response.Header().Get("Referrer-Policy"))
require.Equal(t, "off", response.Header().Get("X-DNS-Prefetch-Control")) require.Equal(t, "off", response.Header().Get("X-DNS-Prefetch-Control"))
require.ElementsMatch(t, []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, response.Header().Values("Cache-Control"))
require.Equal(t, "no-cache", response.Header().Get("Pragma")) require.Equal(t, "no-cache", response.Header().Get("Pragma"))
require.Equal(t, "0", response.Header().Get("Expires")) require.Equal(t, "0", response.Header().Get("Expires"))
// This check is more relaxed since Fosite can override the base header we set.
require.Contains(t, response.Header().Get("Cache-Control"), "no-store")
} }

View File

@ -57,7 +57,8 @@ func TestSupervisorLogin(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Create an HTTP client that can reach the downstream discovery endpoint using the CA certs. // Create an HTTP client that can reach the downstream discovery endpoint using the CA certs.
httpClient := &http.Client{Transport: &http.Transport{ httpClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{RootCAs: ca.Pool()}, TLSClientConfig: &tls.Config{RootCAs: ca.Pool()},
Proxy: func(req *http.Request) (*url.URL, error) { Proxy: func(req *http.Request) (*url.URL, error) {
if env.Proxy == "" { if env.Proxy == "" {
@ -69,7 +70,12 @@ func TestSupervisorLogin(t *testing.T) {
t.Logf("passing request for %s through proxy %s", req.URL, proxyURL.String()) t.Logf("passing request for %s through proxy %s", req.URL, proxyURL.String())
return proxyURL, nil return proxyURL, nil
}, },
}} },
// Don't follow redirects automatically.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
oidcHTTPClientContext := coreosoidc.ClientContext(ctx, httpClient) oidcHTTPClientContext := coreosoidc.ClientContext(ctx, httpClient)
// Use the CA to issue a TLS server cert. // Use the CA to issue a TLS server cert.
@ -144,6 +150,14 @@ func TestSupervisorLogin(t *testing.T) {
pkceParam.Method(), pkceParam.Method(),
) )
// Make the authorize request one "manually" so we can check its response headers.
authorizeRequest, err := http.NewRequestWithContext(ctx, http.MethodGet, downstreamAuthorizeURL, nil)
require.NoError(t, err)
authorizeResp, err := httpClient.Do(authorizeRequest)
require.NoError(t, err)
require.NoError(t, authorizeResp.Body.Close())
expectSecurityHeaders(t, authorizeResp)
// Open the web browser and navigate to the downstream authorize URL. // Open the web browser and navigate to the downstream authorize URL.
page := browsertest.Open(t) page := browsertest.Open(t)
t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL)) t.Logf("opening browser to downstream authorize URL %s", library.MaskTokens(downstreamAuthorizeURL))
@ -306,3 +320,16 @@ func doTokenExchange(t *testing.T, config *oauth2.Config, tokenResponse *oauth2.
require.NoError(t, err) require.NoError(t, err)
t.Logf("exchanged token claims:\n%s", string(indentedClaims)) t.Logf("exchanged token claims:\n%s", string(indentedClaims))
} }
func expectSecurityHeaders(t *testing.T, response *http.Response) {
h := response.Header
assert.Equal(t, "default-src 'none'; frame-ancestors 'none'", h.Get("Content-Security-Policy"))
assert.Equal(t, "DENY", h.Get("X-Frame-Options"))
assert.Equal(t, "1; mode=block", h.Get("X-XSS-Protection"))
assert.Equal(t, "nosniff", h.Get("X-Content-Type-Options"))
assert.Equal(t, "no-referrer", h.Get("Referrer-Policy"))
assert.Equal(t, "off", h.Get("X-DNS-Prefetch-Control"))
assert.Equal(t, "no-cache,no-store,max-age=0,must-revalidate", h.Get("Cache-Control"))
assert.Equal(t, "no-cache", h.Get("Pragma"))
assert.Equal(t, "0", h.Get("Expires"))
}