Use security headers for the form_post page in the POST /login endpoint
Also use more specific test assertions where security headers are expected. And run the unit tests for the login package in parallel.
This commit is contained in:
parent
388cdb6ddd
commit
2e031f727b
@ -2566,7 +2566,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
|
||||
require.Equal(t, test.wantStatus, rsp.Code)
|
||||
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
|
||||
testutil.RequireSecurityHeaders(t, rsp)
|
||||
testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
|
||||
|
||||
if test.wantPasswordGrantCall != nil {
|
||||
test.wantPasswordGrantCall.args.Ctx = reqContext
|
||||
|
@ -1034,7 +1034,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
t.Logf("response: %#v", rsp)
|
||||
t.Logf("response body: %q", rsp.Body.String())
|
||||
|
||||
testutil.RequireSecurityHeaders(t, rsp)
|
||||
testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
|
||||
|
||||
if test.wantAuthcodeExchangeCall != nil {
|
||||
test.wantAuthcodeExchangeCall.args.Ctx = reqContext
|
||||
|
@ -96,7 +96,10 @@ func TestGetLogin(t *testing.T) {
|
||||
|
||||
for _, test := range tests {
|
||||
tt := test
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewGetHandler(tt.idps)
|
||||
target := "/some/path/login?state=" + tt.encodedState
|
||||
if tt.errParam != "" {
|
||||
@ -107,7 +110,7 @@ func TestGetLogin(t *testing.T) {
|
||||
err := handler(rsp, req, tt.encodedState, tt.decodedState)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, test.wantStatus, rsp.Code)
|
||||
require.Equal(t, tt.wantStatus, rsp.Code)
|
||||
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
|
||||
body := rsp.Body.String()
|
||||
require.Equal(t, tt.wantBody, body)
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/httputil/securityheader"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/provider/formposthtml"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
)
|
||||
|
||||
@ -78,7 +79,22 @@ func NewHandler(
|
||||
return handler(w, r, encodedState, decodedState)
|
||||
})
|
||||
|
||||
return securityheader.Wrap(loginHandler)
|
||||
return wrapSecurityHeaders(loginHandler)
|
||||
}
|
||||
|
||||
func wrapSecurityHeaders(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var wrapped http.Handler
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// POST requests can result in the form_post html page, so allow it with CSP headers.
|
||||
wrapped = securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy())
|
||||
default:
|
||||
wrapped = securityheader.Wrap(handler)
|
||||
}
|
||||
|
||||
wrapped.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func RedirectToLoginPage(
|
||||
|
@ -370,9 +370,11 @@ func TestLoginEndpoint(t *testing.T) {
|
||||
tt := test
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(tt.method, tt.path, nil)
|
||||
if test.csrfCookie != "" {
|
||||
req.Header.Set("Cookie", test.csrfCookie)
|
||||
if tt.csrfCookie != "" {
|
||||
req.Header.Set("Cookie", tt.csrfCookie)
|
||||
}
|
||||
rsp := httptest.NewRecorder()
|
||||
|
||||
@ -414,7 +416,11 @@ func TestLoginEndpoint(t *testing.T) {
|
||||
|
||||
subject.ServeHTTP(rsp, req)
|
||||
|
||||
testutil.RequireSecurityHeaders(t, rsp)
|
||||
if tt.method == http.MethodPost {
|
||||
testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
|
||||
} else {
|
||||
testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.wantStatus, rsp.Code)
|
||||
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
|
||||
|
@ -617,6 +617,8 @@ func TestPostLoginEndpoint(t *testing.T) {
|
||||
tt := test
|
||||
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeClient := fake.NewSimpleClientset()
|
||||
secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
|
||||
|
||||
@ -650,7 +652,7 @@ func TestPostLoginEndpoint(t *testing.T) {
|
||||
|
||||
require.Equal(t, tt.wantStatus, rsp.Code)
|
||||
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
|
||||
require.Equal(t, test.wantBodyString, rsp.Body.String())
|
||||
require.Equal(t, tt.wantBodyString, rsp.Body.String())
|
||||
|
||||
actualLocation := rsp.Header().Get("Location")
|
||||
|
||||
@ -660,30 +662,30 @@ func TestPostLoginEndpoint(t *testing.T) {
|
||||
oidctestutil.RequireAuthCodeRegexpMatch(
|
||||
t,
|
||||
actualLocation,
|
||||
test.wantRedirectLocationRegexp,
|
||||
tt.wantRedirectLocationRegexp,
|
||||
kubeClient,
|
||||
secretsClient,
|
||||
kubeOauthStore,
|
||||
test.wantDownstreamGrantedScopes,
|
||||
test.wantDownstreamIDTokenSubject,
|
||||
test.wantDownstreamIDTokenUsername,
|
||||
test.wantDownstreamIDTokenGroups,
|
||||
test.wantDownstreamRequestedScopes,
|
||||
test.wantDownstreamPKCEChallenge,
|
||||
test.wantDownstreamPKCEChallengeMethod,
|
||||
test.wantDownstreamNonce,
|
||||
tt.wantDownstreamGrantedScopes,
|
||||
tt.wantDownstreamIDTokenSubject,
|
||||
tt.wantDownstreamIDTokenUsername,
|
||||
tt.wantDownstreamIDTokenGroups,
|
||||
tt.wantDownstreamRequestedScopes,
|
||||
tt.wantDownstreamPKCEChallenge,
|
||||
tt.wantDownstreamPKCEChallengeMethod,
|
||||
tt.wantDownstreamNonce,
|
||||
downstreamClientID,
|
||||
test.wantDownstreamRedirectURI,
|
||||
test.wantDownstreamCustomSessionData,
|
||||
tt.wantDownstreamRedirectURI,
|
||||
tt.wantDownstreamCustomSessionData,
|
||||
)
|
||||
case tt.wantRedirectToLoginPageError != "":
|
||||
expectedLocation := downstreamIssuer + oidc.PinnipedLoginPath +
|
||||
"?err=" + tt.wantRedirectToLoginPageError + "&state=" + happyEncodedUpstreamState
|
||||
require.Equal(t, expectedLocation, actualLocation)
|
||||
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords)
|
||||
require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
|
||||
case tt.wantRedirectLocationString != "":
|
||||
require.Equal(t, tt.wantRedirectLocationString, actualLocation)
|
||||
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords)
|
||||
require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
|
||||
default:
|
||||
require.Failf(t, "test should have expected a redirect",
|
||||
"actual location was %q", actualLocation)
|
||||
|
@ -1,4 +1,4 @@
|
||||
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
|
||||
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package testutil
|
||||
@ -54,9 +54,35 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
|
||||
require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets)
|
||||
}
|
||||
|
||||
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
|
||||
// This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use.
|
||||
require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'")
|
||||
func RequireSecurityHeadersWithFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
|
||||
// Loosely confirm that the unique CSPs needed for the form_post page were used.
|
||||
cspHeader := response.Header().Get("Content-Security-Policy")
|
||||
require.Contains(t, cspHeader, "script-src '") // loose assertion
|
||||
require.Contains(t, cspHeader, "style-src '") // loose assertion
|
||||
require.Contains(t, cspHeader, "img-src data:")
|
||||
require.Contains(t, cspHeader, "connect-src *")
|
||||
|
||||
// Also require all the usual security headers.
|
||||
requireSecurityHeaders(t, response)
|
||||
}
|
||||
|
||||
func RequireSecurityHeadersWithoutFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
|
||||
// Confirm that the unique CSPs needed for the form_post page were NOT used.
|
||||
cspHeader := response.Header().Get("Content-Security-Policy")
|
||||
require.NotContains(t, cspHeader, "script-src")
|
||||
require.NotContains(t, cspHeader, "style-src")
|
||||
require.NotContains(t, cspHeader, "img-src data:")
|
||||
require.NotContains(t, cspHeader, "connect-src *")
|
||||
|
||||
// Also require all the usual security headers.
|
||||
requireSecurityHeaders(t, response)
|
||||
}
|
||||
|
||||
func requireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
|
||||
// Loosely confirm that the generic CSPs were used.
|
||||
cspHeader := response.Header().Get("Content-Security-Policy")
|
||||
require.Contains(t, cspHeader, "default-src 'none'")
|
||||
require.Contains(t, cspHeader, "frame-ancestors 'none'")
|
||||
|
||||
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
|
||||
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))
|
||||
|
Loading…
Reference in New Issue
Block a user