Add Cache-Control, Pragma, Expires, and X-DNS-Prefetch-Control headers

Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
Ryan Richard 2020-12-14 15:28:32 -08:00 committed by Margo Crawford
parent a5c07042c1
commit 16907e4453
7 changed files with 37 additions and 5 deletions

View File

@ -9,12 +9,22 @@ import "net/http"
// Wrap the provided http.Handler so it sets appropriate security-related response headers. // Wrap the provided http.Handler so it sets appropriate security-related response headers.
func Wrap(wrapped http.Handler) http.Handler { func Wrap(wrapped http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wrapped.ServeHTTP(w, r)
h := w.Header() h := w.Header()
h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'") h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'")
h.Set("X-Frame-Options", "DENY") h.Set("X-Frame-Options", "DENY")
h.Set("X-XSS-Protection", "1; mode=block") h.Set("X-XSS-Protection", "1; mode=block")
h.Set("X-Content-Type-Options", "nosniff") h.Set("X-Content-Type-Options", "nosniff")
h.Set("Referrer-Policy", "no-referrer") h.Set("Referrer-Policy", "no-referrer")
wrapped.ServeHTTP(w, r) h.Set("X-DNS-Prefetch-Control", "off")
// first overwrite existing Cache-Control header with Set, then append more headers with Add
h.Set("Cache-Control", "no-cache")
h.Add("Cache-Control", "no-store")
h.Add("Cache-Control", "max-age=0")
h.Add("Cache-Control", "must-revalidate")
h.Set("Pragma", "no-cache")
h.Set("Expires", "0")
}) })
} }

View File

@ -26,5 +26,9 @@ func TestWrap(t *testing.T) {
"X-Content-Type-Options": []string{"nosniff"}, "X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"}, "X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"}, "X-Xss-Protection": []string{"1; mode=block"},
"X-Dns-Prefetch-Control": []string{"off"},
"Cache-Control": []string{"no-cache", "no-store", "max-age=0", "must-revalidate"},
"Pragma": []string{"no-cache"},
"Expires": []string{"0"},
}, rec.Header()) }, rec.Header())
} }

View File

@ -16,6 +16,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
@ -34,7 +35,7 @@ func NewHandler(
upstreamStateEncoder oidc.Encoder, upstreamStateEncoder oidc.Encoder,
cookieCodec oidc.Codec, cookieCodec oidc.Codec,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return securityheader.Wrap(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost && r.Method != http.MethodGet { if r.Method != http.MethodPost && r.Method != http.MethodGet {
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest // https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// Authorization Servers MUST support the use of the HTTP GET and POST methods defined in // Authorization Servers MUST support the use of the HTTP GET and POST methods defined in
@ -142,7 +143,7 @@ func NewHandler(
) )
return nil return nil
}) }))
} }
func readCSRFCookie(r *http.Request, codec oidc.Decoder) csrftoken.CSRFToken { func readCSRFCookie(r *http.Request, codec oidc.Decoder) csrftoken.CSRFToken {

View File

@ -773,6 +773,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
testutil.RequireSecurityHeaders(t, rsp)
actualLocation := rsp.Header().Get("Location") actualLocation := rsp.Header().Get("Location")
if test.wantLocationHeader != "" { if test.wantLocationHeader != "" {

View File

@ -17,6 +17,7 @@ import (
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
@ -45,7 +46,7 @@ func NewHandler(
stateDecoder, cookieDecoder oidc.Decoder, stateDecoder, cookieDecoder oidc.Decoder,
redirectURI string, redirectURI string,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return securityheader.Wrap(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder) state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil { if err != nil {
return err return err
@ -108,7 +109,7 @@ func NewHandler(
oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder) oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)
return nil return nil
}) }))
} }
func authcode(r *http.Request) string { func authcode(r *http.Request) string {

View File

@ -477,6 +477,8 @@ func TestCallbackEndpoint(t *testing.T) {
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
testutil.RequireSecurityHeaders(t, rsp)
if test.wantExchangeAndValidateTokensCall != nil { if test.wantExchangeAndValidateTokensCall != nil {
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
test.wantExchangeAndValidateTokensCall.Ctx = req.Context() test.wantExchangeAndValidateTokensCall.Ctx = req.Context()

View File

@ -6,6 +6,7 @@ package testutil
import ( import (
"context" "context"
"mime" "mime"
"net/http/httptest"
"testing" "testing"
"time" "time"
@ -52,3 +53,15 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
require.NoError(t, err) require.NoError(t, err)
require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets) require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets)
} }
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
require.Equal(t, "default-src 'none'; frame-ancestors 'none'", response.Header().Get("Content-Security-Policy"))
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))
require.Equal(t, "nosniff", response.Header().Get("X-Content-Type-Options"))
require.Equal(t, "no-referrer", response.Header().Get("Referrer-Policy"))
require.Equal(t, "off", response.Header().Get("X-DNS-Prefetch-Control"))
require.ElementsMatch(t, []string{"no-cache", "no-store", "max-age=0", "must-revalidate"}, response.Header().Values("Cache-Control"))
require.Equal(t, "no-cache", response.Header().Get("Pragma"))
require.Equal(t, "0", response.Header().Get("Expires"))
}