diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index fc0cbc53..058cb70c 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -2566,7 +2566,7 @@ func TestAuthorizationEndpoint(t *testing.T) { require.Equal(t, test.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) - testutil.RequireSecurityHeaders(t, rsp) + testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp) if test.wantPasswordGrantCall != nil { test.wantPasswordGrantCall.args.Ctx = reqContext diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 6fc47773..e92974d9 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -1034,7 +1034,7 @@ func TestCallbackEndpoint(t *testing.T) { t.Logf("response: %#v", rsp) t.Logf("response body: %q", rsp.Body.String()) - testutil.RequireSecurityHeaders(t, rsp) + testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp) if test.wantAuthcodeExchangeCall != nil { test.wantAuthcodeExchangeCall.args.Ctx = reqContext diff --git a/internal/oidc/login/get_login_handler_test.go b/internal/oidc/login/get_login_handler_test.go index 3235fcc5..484ee450 100644 --- a/internal/oidc/login/get_login_handler_test.go +++ b/internal/oidc/login/get_login_handler_test.go @@ -96,7 +96,10 @@ func TestGetLogin(t *testing.T) { for _, test := range tests { tt := test + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + handler := NewGetHandler(tt.idps) target := "/some/path/login?state=" + tt.encodedState if tt.errParam != "" { @@ -107,7 +110,7 @@ func TestGetLogin(t *testing.T) { err := handler(rsp, req, tt.encodedState, tt.decodedState) require.NoError(t, err) - require.Equal(t, test.wantStatus, rsp.Code) + require.Equal(t, tt.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) body := rsp.Body.String() require.Equal(t, tt.wantBody, body) diff --git a/internal/oidc/login/login_handler.go b/internal/oidc/login/login_handler.go index 751dc9c4..ce1b3810 100644 --- a/internal/oidc/login/login_handler.go +++ b/internal/oidc/login/login_handler.go @@ -11,6 +11,7 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/internal/plog" ) @@ -78,7 +79,22 @@ func NewHandler( return handler(w, r, encodedState, decodedState) }) - return securityheader.Wrap(loginHandler) + return wrapSecurityHeaders(loginHandler) +} + +func wrapSecurityHeaders(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var wrapped http.Handler + switch r.Method { + case http.MethodPost: + // POST requests can result in the form_post html page, so allow it with CSP headers. + wrapped = securityheader.WrapWithCustomCSP(handler, formposthtml.ContentSecurityPolicy()) + default: + wrapped = securityheader.Wrap(handler) + } + + wrapped.ServeHTTP(w, r) + }) } func RedirectToLoginPage( diff --git a/internal/oidc/login/login_handler_test.go b/internal/oidc/login/login_handler_test.go index 347f0760..79a0ee65 100644 --- a/internal/oidc/login/login_handler_test.go +++ b/internal/oidc/login/login_handler_test.go @@ -370,9 +370,11 @@ func TestLoginEndpoint(t *testing.T) { tt := test t.Run(tt.name, func(t *testing.T) { + t.Parallel() + req := httptest.NewRequest(tt.method, tt.path, nil) - if test.csrfCookie != "" { - req.Header.Set("Cookie", test.csrfCookie) + if tt.csrfCookie != "" { + req.Header.Set("Cookie", tt.csrfCookie) } rsp := httptest.NewRecorder() @@ -414,7 +416,11 @@ func TestLoginEndpoint(t *testing.T) { subject.ServeHTTP(rsp, req) - testutil.RequireSecurityHeaders(t, rsp) + if tt.method == http.MethodPost { + testutil.RequireSecurityHeadersWithFormPostCSPs(t, rsp) + } else { + testutil.RequireSecurityHeadersWithoutFormPostCSPs(t, rsp) + } require.Equal(t, tt.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) diff --git a/internal/oidc/login/post_login_handler_test.go b/internal/oidc/login/post_login_handler_test.go index 1e4fa437..f74d67d8 100644 --- a/internal/oidc/login/post_login_handler_test.go +++ b/internal/oidc/login/post_login_handler_test.go @@ -617,6 +617,8 @@ func TestPostLoginEndpoint(t *testing.T) { tt := test t.Run(tt.name, func(t *testing.T) { + t.Parallel() + kubeClient := fake.NewSimpleClientset() secretsClient := kubeClient.CoreV1().Secrets("some-namespace") @@ -650,7 +652,7 @@ func TestPostLoginEndpoint(t *testing.T) { require.Equal(t, tt.wantStatus, rsp.Code) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) - require.Equal(t, test.wantBodyString, rsp.Body.String()) + require.Equal(t, tt.wantBodyString, rsp.Body.String()) actualLocation := rsp.Header().Get("Location") @@ -660,30 +662,30 @@ func TestPostLoginEndpoint(t *testing.T) { oidctestutil.RequireAuthCodeRegexpMatch( t, actualLocation, - test.wantRedirectLocationRegexp, + tt.wantRedirectLocationRegexp, kubeClient, secretsClient, kubeOauthStore, - test.wantDownstreamGrantedScopes, - test.wantDownstreamIDTokenSubject, - test.wantDownstreamIDTokenUsername, - test.wantDownstreamIDTokenGroups, - test.wantDownstreamRequestedScopes, - test.wantDownstreamPKCEChallenge, - test.wantDownstreamPKCEChallengeMethod, - test.wantDownstreamNonce, + tt.wantDownstreamGrantedScopes, + tt.wantDownstreamIDTokenSubject, + tt.wantDownstreamIDTokenUsername, + tt.wantDownstreamIDTokenGroups, + tt.wantDownstreamRequestedScopes, + tt.wantDownstreamPKCEChallenge, + tt.wantDownstreamPKCEChallengeMethod, + tt.wantDownstreamNonce, downstreamClientID, - test.wantDownstreamRedirectURI, - test.wantDownstreamCustomSessionData, + tt.wantDownstreamRedirectURI, + tt.wantDownstreamCustomSessionData, ) case tt.wantRedirectToLoginPageError != "": expectedLocation := downstreamIssuer + oidc.PinnipedLoginPath + "?err=" + tt.wantRedirectToLoginPageError + "&state=" + happyEncodedUpstreamState require.Equal(t, expectedLocation, actualLocation) - require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) + require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords) case tt.wantRedirectLocationString != "": require.Equal(t, tt.wantRedirectLocationString, actualLocation) - require.Len(t, kubeClient.Actions(), test.wantUnnecessaryStoredRecords) + require.Len(t, kubeClient.Actions(), tt.wantUnnecessaryStoredRecords) default: require.Failf(t, "test should have expected a redirect", "actual location was %q", actualLocation) diff --git a/internal/testutil/assertions.go b/internal/testutil/assertions.go index 9286bff1..b592e07e 100644 --- a/internal/testutil/assertions.go +++ b/internal/testutil/assertions.go @@ -1,4 +1,4 @@ -// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. +// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testutil @@ -54,9 +54,35 @@ func RequireNumberOfSecretsMatchingLabelSelector(t *testing.T, secrets v1.Secret require.Len(t, storedAuthcodeSecrets.Items, expectedNumberOfSecrets) } -func RequireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { - // This is a more relaxed assertion rather than an exact match, so it can cover all the CSP headers we use. - require.Contains(t, response.Header().Get("Content-Security-Policy"), "default-src 'none'") +func RequireSecurityHeadersWithFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) { + // Loosely confirm that the unique CSPs needed for the form_post page were used. + cspHeader := response.Header().Get("Content-Security-Policy") + require.Contains(t, cspHeader, "script-src '") // loose assertion + require.Contains(t, cspHeader, "style-src '") // loose assertion + require.Contains(t, cspHeader, "img-src data:") + require.Contains(t, cspHeader, "connect-src *") + + // Also require all the usual security headers. + requireSecurityHeaders(t, response) +} + +func RequireSecurityHeadersWithoutFormPostCSPs(t *testing.T, response *httptest.ResponseRecorder) { + // Confirm that the unique CSPs needed for the form_post page were NOT used. + cspHeader := response.Header().Get("Content-Security-Policy") + require.NotContains(t, cspHeader, "script-src") + require.NotContains(t, cspHeader, "style-src") + require.NotContains(t, cspHeader, "img-src data:") + require.NotContains(t, cspHeader, "connect-src *") + + // Also require all the usual security headers. + requireSecurityHeaders(t, response) +} + +func requireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) { + // Loosely confirm that the generic CSPs were used. + cspHeader := response.Header().Get("Content-Security-Policy") + require.Contains(t, cspHeader, "default-src 'none'") + require.Contains(t, cspHeader, "frame-ancestors 'none'") require.Equal(t, "DENY", response.Header().Get("X-Frame-Options")) require.Equal(t, "1; mode=block", response.Header().Get("X-XSS-Protection"))