From 602f3c59babdb7090771c4b61fe8d9db3cd1eed1 Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Tue, 15 Dec 2020 21:34:34 -0600 Subject: [PATCH] Fix a regression in securityheader package. The bug itself has to do with when headers are streamed to the client. Once a wrapped handler has sent any bytes to the `http.ResponseWriter`, the value of the map returned from `w.Header()` no longer matters for the response. The fix is fairly trivial, which is to add those response headers before invoking the wrapped handler. The existing unit test didn't catch this due to limitations in `httptest.NewRecorder()`. It is now replaced with a new test that runs a full HTTP test server, which catches the previous bug. Signed-off-by: Matt Moyer --- .../httputil/securityheader/securityheader.go | 3 +- .../securityheader/securityheader_test.go | 37 +++++++++++++++---- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/internal/httputil/securityheader/securityheader.go b/internal/httputil/securityheader/securityheader.go index 42cf1e2f..2def7ede 100644 --- a/internal/httputil/securityheader/securityheader.go +++ b/internal/httputil/securityheader/securityheader.go @@ -9,7 +9,6 @@ import "net/http" // Wrap the provided http.Handler so it sets appropriate security-related response headers. func Wrap(wrapped http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - wrapped.ServeHTTP(w, r) h := w.Header() h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") h.Set("X-Frame-Options", "DENY") @@ -26,5 +25,7 @@ func Wrap(wrapped http.Handler) http.Handler { h.Set("Pragma", "no-cache") h.Set("Expires", "0") + + wrapped.ServeHTTP(w, r) }) } diff --git a/internal/httputil/securityheader/securityheader_test.go b/internal/httputil/securityheader/securityheader_test.go index 715e7cd5..e8772f51 100644 --- a/internal/httputil/securityheader/securityheader_test.go +++ b/internal/httputil/securityheader/securityheader_test.go @@ -4,22 +4,40 @@ package securityheader import ( + "context" + "io/ioutil" "net/http" "net/http/httptest" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWrap(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test-Header", "test value") _, _ = w.Write([]byte("hello world")) - }) - rec := httptest.NewRecorder() - Wrap(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil)) - require.Equal(t, http.StatusOK, rec.Code) - require.Equal(t, "hello world", rec.Body.String()) - require.EqualValues(t, http.Header{ + }))) + t.Cleanup(testServer.Close) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "hello world", string(respBody)) + + expected := http.Header{ + "X-Test-Header": []string{"test value"}, "Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"}, "Content-Type": []string{"text/plain; charset=utf-8"}, "Referrer-Policy": []string{"no-referrer"}, @@ -30,5 +48,8 @@ func TestWrap(t *testing.T) { "Cache-Control": []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, "Pragma": []string{"no-cache"}, "Expires": []string{"0"}, - }, rec.Header()) + } + for key, values := range expected { + assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key) + } }