Use security headers for the form_post page in the POST /login endpoint

Also use more specific test assertions where security headers are
expected. And run the unit tests for the login package in parallel.
This commit is contained in:
Ryan Richard 2022-05-03 16:46:09 -07:00
parent 388cdb6ddd
commit 2e031f727b
7 changed files with 78 additions and 25 deletions

View File

@ -2566,7 +2566,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
testutil.RequireSecurityHeaders(t, rsp)
testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
if test.wantPasswordGrantCall != nil {
test.wantPasswordGrantCall.args.Ctx = reqContext

View File

@ -1034,7 +1034,7 @@ func TestCallbackEndpoint(t *testing.T) {
t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String())
testutil.RequireSecurityHeaders(t, rsp)
testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
if test.wantAuthcodeExchangeCall != nil {
test.wantAuthcodeExchangeCall.args.Ctx = reqContext

View File

@ -96,7 +96,10 @@ func TestGetLogin(t *testing.T) {
for _, test := range tests {
tt := test
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := NewGetHandler(tt.idps)
target := "/some/path/login?state=" + tt.encodedState
if tt.errParam != "" {
@ -107,7 +110,7 @@ func TestGetLogin(t *testing.T) {
err := handler(rsp, req, tt.encodedState, tt.decodedState)
require.NoError(t, err)
require.Equal(t, test.wantStatus, rsp.Code)
require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
body := rsp.Body.String()
require.Equal(t, tt.wantBody, body)

View File

@ -11,6 +11,7 @@ import (
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/internal/plog"
)
@ -78,7 +79,22 @@ func NewHandler(
return handler(w, r, encodedState, decodedState)
})
return securityheader.Wrap(loginHandler)
return wrapSecurityHeaders(loginHandler)
}
func wrapSecurityHeaders(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var wrapped http.Handler
switch r.Method {
case http.MethodPost:
// POST requests can result in the form_post html page, so allow it with CSP headers.
wrapped = securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy())
default:
wrapped = securityheader.Wrap(handler)
}
wrapped.ServeHTTP(w, r)
})
}
func RedirectToLoginPage(

View File

@ -370,9 +370,11 @@ func TestLoginEndpoint(t *testing.T) {
tt := test
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.path, nil)
if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie)
if tt.csrfCookie != "" {
req.Header.Set("Cookie", tt.csrfCookie)
}
rsp := httptest.NewRecorder()
@ -414,7 +416,11 @@ func TestLoginEndpoint(t *testing.T) {
subject.ServeHTTP(rsp, req)
testutil.RequireSecurityHeaders(t, rsp)
if tt.method == http.MethodPost {
testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
} else {
testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
}
require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)

View File

@ -617,6 +617,8 @@ func TestPostLoginEndpoint(t *testing.T) {
tt := test
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeClient := fake.NewSimpleClientset()
secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
@ -650,7 +652,7 @@ func TestPostLoginEndpoint(t *testing.T) {
require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
require.Equal(t, test.wantBodyString, rsp.Body.String())
require.Equal(t, tt.wantBodyString, rsp.Body.String())
actualLocation := rsp.Header().Get("Location")
@ -660,30 +662,30 @@ func TestPostLoginEndpoint(t *testing.T) {
oidctestutil.RequireAuthCodeRegexpMatch(
t,
actualLocation,
test.wantRedirectLocationRegexp,
tt.wantRedirectLocationRegexp,
kubeClient,
secretsClient,
kubeOauthStore,
test.wantDownstreamGrantedScopes,
test.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenUsername,
test.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes,
test.wantDownstreamPKCEChallenge,
test.wantDownstreamPKCEChallengeMethod,
test.wantDownstreamNonce,
tt.wantDownstreamGrantedScopes,
tt.wantDownstreamIDTokenSubject,
tt.wantDownstreamIDTokenUsername,
tt.wantDownstreamIDTokenGroups,
tt.wantDownstreamRequestedScopes,
tt.wantDownstreamPKCEChallenge,
tt.wantDownstreamPKCEChallengeMethod,
tt.wantDownstreamNonce,
downstreamClientID,
test.wantDownstreamRedirectURI,
test.wantDownstreamCustomSessionData,
tt.wantDownstreamRedirectURI,
tt.wantDownstreamCustomSessionData,
)
case tt.wantRedirectToLoginPageError != "":
expectedLocation := downstreamIssuer + oidc.PinnipedLoginPath +
"?err=" + tt.wantRedirectToLoginPageError + "&state=" + happyEncodedUpstreamState
require.Equal(t, expectedLocation, actualLocation)
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords)
require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
case tt.wantRedirectLocationString != "":
require.Equal(t, tt.wantRedirectLocationString, actualLocation)
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords)
require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
default:
require.Failf(t, "test should have expected a redirect",
"actual location was %q", actualLocation)

View File

@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
@ -54,9 +54,35 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets)
}
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
// This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use.
require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'")
func RequireSecurityHeadersWithFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
// Loosely confirm that the unique CSPs needed for the form_post page were used.
cspHeader := response.Header().Get("Content-Security-Policy")
require.Contains(t, cspHeader, "script-src '") // loose assertion
require.Contains(t, cspHeader, "style-src '") // loose assertion
require.Contains(t, cspHeader, "img-src data:")
require.Contains(t, cspHeader, "connect-src *")
// Also require all the usual security headers.
requireSecurityHeaders(t, response)
}
func RequireSecurityHeadersWithoutFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
// Confirm that the unique CSPs needed for the form_post page were NOT used.
cspHeader := response.Header().Get("Content-Security-Policy")
require.NotContains(t, cspHeader, "script-src")
require.NotContains(t, cspHeader, "style-src")
require.NotContains(t, cspHeader, "img-src data:")
require.NotContains(t, cspHeader, "connect-src *")
// Also require all the usual security headers.
requireSecurityHeaders(t, response)
}
func requireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
// Loosely confirm that the generic CSPs were used.
cspHeader := response.Header().Get("Content-Security-Policy")
require.Contains(t, cspHeader, "default-src 'none'")
require.Contains(t, cspHeader, "frame-ancestors 'none'")
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))