Use security headers for the form_post page in the POST /login endpoint

Also use more specific test assertions where security headers are
expected. And run the unit tests for the login package in parallel.
This commit is contained in:
Ryan Richard 2022-05-03 16:46:09 -07:00
parent 388cdb6ddd
commit 2e031f727b
7 changed files with 78 additions and 25 deletions

View File

@ -2566,7 +2566,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
testutil.RequireSecurityHeaders(t, rsp) testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
if test.wantPasswordGrantCall != nil { if test.wantPasswordGrantCall != nil {
test.wantPasswordGrantCall.args.Ctx = reqContext test.wantPasswordGrantCall.args.Ctx = reqContext

View File

@ -1034,7 +1034,7 @@ func TestCallbackEndpoint(t *testing.T) {
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
testutil.RequireSecurityHeaders(t, rsp) testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
if test.wantAuthcodeExchangeCall != nil { if test.wantAuthcodeExchangeCall != nil {
test.wantAuthcodeExchangeCall.args.Ctx = reqContext test.wantAuthcodeExchangeCall.args.Ctx = reqContext

View File

@ -96,7 +96,10 @@ func TestGetLogin(t *testing.T) {
for _, test := range tests { for _, test := range tests {
tt := test tt := test
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
handler := NewGetHandler(tt.idps) handler := NewGetHandler(tt.idps)
target := "/some/path/login?state=" + tt.encodedState target := "/some/path/login?state=" + tt.encodedState
if tt.errParam != "" { if tt.errParam != "" {
@ -107,7 +110,7 @@ func TestGetLogin(t *testing.T) {
err := handler(rsp, req, tt.encodedState, tt.decodedState) err := handler(rsp, req, tt.encodedState, tt.decodedState)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
body := rsp.Body.String() body := rsp.Body.String()
require.Equal(t, tt.wantBody, body) require.Equal(t, tt.wantBody, body)

View File

@ -11,6 +11,7 @@ import (
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -78,7 +79,22 @@ func NewHandler(
return handler(w, r, encodedState, decodedState) return handler(w, r, encodedState, decodedState)
}) })
return securityheader.Wrap(loginHandler) return wrapSecurityHeaders(loginHandler)
}
func wrapSecurityHeaders(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var wrapped http.Handler
switch r.Method {
case http.MethodPost:
// POST requests can result in the form_post html page, so allow it with CSP headers.
wrapped = securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy())
default:
wrapped = securityheader.Wrap(handler)
}
wrapped.ServeHTTP(w, r)
})
} }
func RedirectToLoginPage( func RedirectToLoginPage(

View File

@ -370,9 +370,11 @@ func TestLoginEndpoint(t *testing.T) {
tt := test tt := test
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
req := httptest.NewRequest(tt.method, tt.path, nil) req := httptest.NewRequest(tt.method, tt.path, nil)
if test.csrfCookie != "" { if tt.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", tt.csrfCookie)
} }
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
@ -414,7 +416,11 @@ func TestLoginEndpoint(t *testing.T) {
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
testutil.RequireSecurityHeaders(t, rsp) if tt.method == http.MethodPost {
testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp)
} else {
testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp)
}
require.Equal(t, tt.wantStatus, rsp.Code) require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)

View File

@ -617,6 +617,8 @@ func TestPostLoginEndpoint(t *testing.T) {
tt := test tt := test
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeClient := fake.NewSimpleClientset() kubeClient := fake.NewSimpleClientset()
secretsClient := kubeClient.CoreV1().Secrets("some-namespace") secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
@ -650,7 +652,7 @@ func TestPostLoginEndpoint(t *testing.T) {
require.Equal(t, tt.wantStatus, rsp.Code) require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
require.Equal(t, test.wantBodyString, rsp.Body.String()) require.Equal(t, tt.wantBodyString, rsp.Body.String())
actualLocation := rsp.Header().Get("Location") actualLocation := rsp.Header().Get("Location")
@ -660,30 +662,30 @@ func TestPostLoginEndpoint(t *testing.T) {
oidctestutil.RequireAuthCodeRegexpMatch( oidctestutil.RequireAuthCodeRegexpMatch(
t, t,
actualLocation, actualLocation,
test.wantRedirectLocationRegexp, tt.wantRedirectLocationRegexp,
kubeClient, kubeClient,
secretsClient, secretsClient,
kubeOauthStore, kubeOauthStore,
test.wantDownstreamGrantedScopes, tt.wantDownstreamGrantedScopes,
test.wantDownstreamIDTokenSubject, tt.wantDownstreamIDTokenSubject,
test.wantDownstreamIDTokenUsername, tt.wantDownstreamIDTokenUsername,
test.wantDownstreamIDTokenGroups, tt.wantDownstreamIDTokenGroups,
test.wantDownstreamRequestedScopes, tt.wantDownstreamRequestedScopes,
test.wantDownstreamPKCEChallenge, tt.wantDownstreamPKCEChallenge,
test.wantDownstreamPKCEChallengeMethod, tt.wantDownstreamPKCEChallengeMethod,
test.wantDownstreamNonce, tt.wantDownstreamNonce,
downstreamClientID, downstreamClientID,
test.wantDownstreamRedirectURI, tt.wantDownstreamRedirectURI,
test.wantDownstreamCustomSessionData, tt.wantDownstreamCustomSessionData,
) )
case tt.wantRedirectToLoginPageError != "": case tt.wantRedirectToLoginPageError != "":
expectedLocation := downstreamIssuer + oidc.PinnipedLoginPath + expectedLocation := downstreamIssuer + oidc.PinnipedLoginPath +
"?err=" + tt.wantRedirectToLoginPageError + "&state=" + happyEncodedUpstreamState "?err=" + tt.wantRedirectToLoginPageError + "&state=" + happyEncodedUpstreamState
require.Equal(t, expectedLocation, actualLocation) require.Equal(t, expectedLocation, actualLocation)
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
case tt.wantRedirectLocationString != "": case tt.wantRedirectLocationString != "":
require.Equal(t, tt.wantRedirectLocationString, actualLocation) require.Equal(t, tt.wantRedirectLocationString, actualLocation)
require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords)
default: default:
require.Failf(t, "test should have expected a redirect", require.Failf(t, "test should have expected a redirect",
"actual location was %q", actualLocation) "actual location was %q", actualLocation)

View File

@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package testutil package testutil
@ -54,9 +54,35 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret
require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets) require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets)
} }
func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { func RequireSecurityHeadersWithFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
// This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use. // Loosely confirm that the unique CSPs needed for the form_post page were used.
require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'") cspHeader := response.Header().Get("Content-Security-Policy")
require.Contains(t, cspHeader, "script-src '") // loose assertion
require.Contains(t, cspHeader, "style-src '") // loose assertion
require.Contains(t, cspHeader, "img-src data:")
require.Contains(t, cspHeader, "connect-src *")
// Also require all the usual security headers.
requireSecurityHeaders(t, response)
}
func RequireSecurityHeadersWithoutFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) {
// Confirm that the unique CSPs needed for the form_post page were NOT used.
cspHeader := response.Header().Get("Content-Security-Policy")
require.NotContains(t, cspHeader, "script-src")
require.NotContains(t, cspHeader, "style-src")
require.NotContains(t, cspHeader, "img-src data:")
require.NotContains(t, cspHeader, "connect-src *")
// Also require all the usual security headers.
requireSecurityHeaders(t, response)
}
func requireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
// Loosely confirm that the generic CSPs were used.
cspHeader := response.Header().Get("Content-Security-Policy")
require.Contains(t, cspHeader, "default-src 'none'")
require.Contains(t, cspHeader, "frame-ancestors 'none'")
require.Equal(t, "DENY", response.Header().Get("X-Frame-Options")) require.Equal(t, "DENY", response.Header().Get("X-Frame-Options"))
require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection")) require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))