Fix a regression in securityheader package.

The bug itself has to do with when headers are streamed to the client. Once a wrapped handler has sent any bytes to the `http.ResponseWriter`, the value of the map returned from `w.Header()` no longer matters for the response. The fix is fairly trivial, which is to add those response headers before invoking the wrapped handler.

The existing unit test didn't catch this due to limitations in `httptest.NewRecorder()`. It is now replaced with a new test that runs a full HTTP test server, which catches the previous bug.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-12-15 21:34:34 -06:00
parent bc1dc0805e
commit 602f3c59ba
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 31 additions and 9 deletions

View File

@ -9,7 +9,6 @@ import "net/http"
// Wrap the provided http.Handler so it sets appropriate security-related response headers.
func Wrap(wrapped http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrapped.ServeHTTP(w, r)
h := w.Header()
h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'")
h.Set("X-Frame-Options", "DENY")
@ -26,5 +25,7 @@ func Wrap(wrapped http.Handler) http.Handler {
h.Set("Pragma", "no-cache")
h.Set("Expires", "0")
wrapped.ServeHTTP(w, r)
})
}

View File

@ -4,22 +4,40 @@
package securityheader
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
testServer := httptest.NewServer(Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Test-Header", "test value")
_, _ = w.Write([]byte("hello world"))
})
rec := httptest.NewRecorder()
Wrap(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "hello world", rec.Body.String())
require.EqualValues(t, http.Header{
})))
t.Cleanup(testServer.Close)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, testServer.URL, nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
respBody, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "hello world", string(respBody))
expected := http.Header{
"X-Test-Header": []string{"test value"},
"Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"},
@ -30,5 +48,8 @@ func TestWrap(t *testing.T) {
"Cache-Control": []string{"no-cache", "no-store", "max-age=0", "must-revalidate"},
"Pragma": []string{"no-cache"},
"Expires": []string{"0"},
}, rec.Header())
}
for key, values := range expected {
assert.Equalf(t, values, resp.Header.Values(key), "unexpected values for header %s", key)
}
}