From 65eed7e74282b08d257b30697ba0f5e301e24f5f Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Tue, 26 Apr 2022 15:30:39 -0700 Subject: [PATCH] Implement login_handler.go to defer to other handlers The other handlers for GET and POST requests are not yet implemented in this commit. The shared handler code in login_handler.go takes care of things checking the method, checking the CSRF cookie, decoding the state param, and adding security headers on behalf of both the GET and POST handlers. Some code has been extracted from callback_handler.go to be shared. --- internal/oidc/callback/callback_handler.go | 57 +-- .../oidc/callback/callback_handler_test.go | 37 +- internal/oidc/login/get_login_handler.go | 17 + internal/oidc/login/login_handler.go | 63 ++- internal/oidc/login/login_handler_test.go | 448 ++++++++++++++++++ internal/oidc/login/post_login_handler.go | 19 + internal/oidc/oidc.go | 68 +++ internal/oidc/provider/manager/manager.go | 7 +- .../testutil/oidctestutil/oidctestutil.go | 39 ++ 9 files changed, 656 insertions(+), 99 deletions(-) create mode 100644 internal/oidc/login/get_login_handler.go create mode 100644 internal/oidc/login/login_handler_test.go create mode 100644 internal/oidc/login/post_login_handler.go diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index fbf13728..bcb8bf1b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,7 +5,6 @@ package callback import ( - "crypto/subtle" "net/http" "net/url" @@ -14,7 +13,6 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/oidc" - "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/formposthtml" @@ -102,9 +100,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } - csrfValue, err := readCSRFCookie(r, cookieDecoder) + _, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder) if err != nil { - plog.InfoErr("error reading CSRF cookie", err) + plog.InfoErr("state or CSRF error", err) return nil, err } @@ -113,23 +111,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, httperr.New(http.StatusBadRequest, "code param not found") } - if r.FormValue("state") == "" { - plog.Info("state param not found") - return nil, httperr.New(http.StatusBadRequest, "state param not found") - } - - state, err := readState(r, stateDecoder) - if err != nil { - plog.InfoErr("error reading state", err) - return nil, err - } - - if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 { - plog.InfoErr("CSRF value does not match", err) - return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) - } - - return state, nil + return decodedState, nil } func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister) provider.UpstreamOIDCIdentityProviderI { @@ -140,36 +122,3 @@ func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCId } return nil } - -func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) { - receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) - if err != nil { - // Error means that the cookie was not found - return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) - } - - var csrfFromCookie csrftoken.CSRFToken - err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) - if err != nil { - return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) - } - - return csrfFromCookie, nil -} - -func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { - var state oidc.UpstreamStateParamData - if err := stateDecoder.Decode( - oidc.UpstreamStateParamEncodingName, - r.FormValue("state"), - &state, - ); err != nil { - return nil, httperr.New(http.StatusBadRequest, "error reading state") - } - - if state.FormatVersion != oidc.UpstreamStateParamFormatVersion { - return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") - } - - return &state, nil -} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 21380979..6fc47773 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -1156,10 +1156,8 @@ func (r *requestPath) String() string { return path + params.Encode() } -type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat - -func happyUpstreamStateParam() *upstreamStateParamBuilder { - return &upstreamStateParamBuilder{ +func happyUpstreamStateParam() *oidctestutil.UpstreamStateParamBuilder { + return &oidctestutil.UpstreamStateParamBuilder{ U: happyUpstreamIDPName, P: happyDownstreamRequestParams, T: "oidc", @@ -1170,37 +1168,6 @@ func happyUpstreamStateParam() *upstreamStateParamBuilder { } } -func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { - state, err := stateEncoder.Encode("s", b) - require.NoError(t, err) - return state -} - -func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder { - b.P = params - return b -} - -func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder { - b.N = nonce - return b -} - -func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder { - b.C = csrf - return b -} - -func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder { - b.K = pkce - return b -} - -func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder { - b.V = version - return b -} - func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName(happyUpstreamIDPName). diff --git a/internal/oidc/login/get_login_handler.go b/internal/oidc/login/get_login_handler.go new file mode 100644 index 00000000..e1d6ffb6 --- /dev/null +++ b/internal/oidc/login/get_login_handler.go @@ -0,0 +1,17 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + + "go.pinniped.dev/internal/oidc" +) + +func NewGetHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { + // TODO + return nil + } +} diff --git a/internal/oidc/login/login_handler.go b/internal/oidc/login/login_handler.go index 10727b3c..a8e65e0e 100644 --- a/internal/oidc/login/login_handler.go +++ b/internal/oidc/login/login_handler.go @@ -5,19 +5,64 @@ package login import ( "net/http" + + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/plog" ) -// NewHandler returns an http.Handler that serves the login endpoint for IDPs that -// don't have their own Web UI. -func NewHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, `Method not allowed (try GET)`, http.StatusMethodNotAllowed) - return +// HandlerFunc is a function that can handle either a GET or POST request for the login endpoint. +type HandlerFunc func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, +) error + +// NewHandler returns a http.Handler that serves the login endpoint for IDPs that don't have their own web UI for login. +// +// This handler takes care of the shared concerns between the GET and POST methods of the login endpoint: +// checking the method, checking the CSRF cookie, decoding the state param, and adding security headers. +// Then it defers the rest of the handling to the passed in handler functions for GET and POST requests. +// Note that CSRF protection isn't needed on GET requests, but it doesn't hurt. Putting it here +// keeps the implementations and tests of HandlerFunc simpler since they won't need to deal with any decoders. +// Users should always initially get redirected to this page from the authorization endpoint, and never need +// to navigate directly to this page in their browser without going through the authorization endpoint first. +// Once their browser has landed on this page, it should be okay for the user to refresh the browser. +func NewHandler( + stateDecoder oidc.Decoder, + cookieDecoder oidc.Decoder, + getHandler HandlerFunc, // use NewGetHandler() for production + postHandler HandlerFunc, // use NewPostHandler() for production +) http.Handler { + loginHandler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + var handler HandlerFunc + switch r.Method { + case http.MethodGet: + handler = getHandler + case http.MethodPost: + handler = postHandler + default: + return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) } - _, err := w.Write([]byte("

hello world

")) + + encodedState, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + plog.InfoErr("state or CSRF error", err) + return err } + + switch decodedState.UpstreamType { + case string(idpdiscoveryv1alpha1.IDPTypeLDAP), string(idpdiscoveryv1alpha1.IDPTypeActiveDirectory): + // these are the types supported by this endpoint, so no error here + default: + return httperr.Newf(http.StatusBadRequest, "not a supported upstream IDP type for this endpoint: %q", decodedState.UpstreamType) + } + + return handler(w, r, encodedState, decodedState) }) + + return securityheader.Wrap(loginHandler) } diff --git a/internal/oidc/login/login_handler_test.go b/internal/oidc/login/login_handler_test.go new file mode 100644 index 00000000..c77758da --- /dev/null +++ b/internal/oidc/login/login_handler_test.go @@ -0,0 +1,448 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/oidctestutil" +) + +func TestLoginEndpoint(t *testing.T) { + const ( + htmlContentType = "text/html; charset=utf-8" + happyGetResult = "

get handler result

" + happyPostResult = "

post handler result

" + + happyUpstreamIDPName = "upstream-idp-name" + happyUpstreamIDPType = "ldap" + happyDownstreamCSRF = "test-csrf" + happyDownstreamPKCE = "test-pkce" + happyDownstreamNonce = "test-nonce" + happyDownstreamStateVersion = "2" + + downstreamClientID = "pinniped-cli" + happyDownstreamState = "8b-state" + downstreamNonce = "some-nonce-value" + downstreamPKCEChallenge = "some-challenge" + downstreamPKCEChallengeMethod = "S256" + downstreamRedirectURI = "http://127.0.0.1/callback" + ) + + happyDownstreamScopesRequested := []string{"openid"} + happyDownstreamRequestParamsQuery := url.Values{ + "response_type": []string{"code"}, + "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, + "client_id": []string{downstreamClientID}, + "state": []string{happyDownstreamState}, + "nonce": []string{downstreamNonce}, + "code_challenge": []string{downstreamPKCEChallenge}, + "code_challenge_method": []string{downstreamPKCEChallengeMethod}, + "redirect_uri": []string{downstreamRedirectURI}, + } + happyDownstreamRequestParams := happyDownstreamRequestParamsQuery.Encode() + + expectedHappyDecodedUpstreamStateParam := func() *oidc.UpstreamStateParamData { + return &oidc.UpstreamStateParamData{ + UpstreamName: happyUpstreamIDPName, + UpstreamType: happyUpstreamIDPType, + AuthParams: happyDownstreamRequestParams, + Nonce: happyDownstreamNonce, + CSRFToken: happyDownstreamCSRF, + PKCECode: happyDownstreamPKCE, + FormatVersion: happyDownstreamStateVersion, + } + } + + expectedHappyDecodedUpstreamStateParamForActiveDirectory := func() *oidc.UpstreamStateParamData { + s := expectedHappyDecodedUpstreamStateParam() + s.UpstreamType = "activedirectory" + return s + } + + happyUpstreamStateParam := func() *oidctestutil.UpstreamStateParamBuilder { + return &oidctestutil.UpstreamStateParamBuilder{ + U: happyUpstreamIDPName, + T: happyUpstreamIDPType, + P: happyDownstreamRequestParams, + N: happyDownstreamNonce, + C: happyDownstreamCSRF, + K: happyDownstreamPKCE, + V: happyDownstreamStateVersion, + } + } + + stateEncoderHashKey := []byte("fake-hash-secret") + stateEncoderBlockKey := []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + cookieEncoderHashKey := []byte("fake-hash-secret2") + cookieEncoderBlockKey := []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + happyStateCodec := securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + happyCookieCodec := securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + + happyState := happyUpstreamStateParam().Build(t, happyStateCodec) + happyPathWithState := newRequestPath().WithState(happyState).String() + + happyActiveDirectoryState := happyUpstreamStateParam().WithUpstreamIDPType("activedirectory").Build(t, happyStateCodec) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) + require.NoError(t, err) + happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + + tests := []struct { + name string + method string + path string + csrfCookie string + getHandlerErr error + postHandlerErr error + + wantStatus int + wantContentType string + wantBody string + wantEncodedState string + wantDecodedState *oidc.UpstreamStateParamData + }{ + { + name: "PUT method is invalid", + method: http.MethodPut, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: PUT (try GET or POST)\n", + }, + { + name: "PATCH method is invalid", + method: http.MethodPatch, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: PATCH (try GET or POST)\n", + }, + { + name: "DELETE method is invalid", + method: http.MethodDelete, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: DELETE (try GET or POST)\n", + }, + { + name: "HEAD method is invalid", + method: http.MethodHead, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: HEAD (try GET or POST)\n", + }, + { + name: "CONNECT method is invalid", + method: http.MethodConnect, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: CONNECT (try GET or POST)\n", + }, + { + name: "OPTIONS method is invalid", + method: http.MethodOptions, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: OPTIONS (try GET or POST)\n", + }, + { + name: "TRACE method is invalid", + method: http.MethodTrace, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: TRACE (try GET or POST)\n", + }, + { + name: "state param was not included on GET request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not included on POST request", + method: http.MethodPost, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request", + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: error reading state\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request", + method: http.MethodPost, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: error reading state\n", + }, + { + name: "the CSRF cookie does not exist on GET request", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: "", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF cookie is missing\n", + }, + { + name: "the CSRF cookie does not exist on POST request", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: "", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF cookie is missing\n", + }, + { + name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "cookie csrf value does not match state csrf value on GET request", + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF value does not match\n", + }, + { + name: "cookie csrf value does not match state csrf value on POST request", + method: http.MethodPost, + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF value does not match\n", + }, + { + name: "GET request when upstream IDP type in state param is not supported by this endpoint", + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n", + }, + { + name: "POST request when upstream IDP type in state param is not supported by this endpoint", + method: http.MethodPost, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n", + }, + { + name: "valid GET request when GET endpoint handler returns an error", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + getHandlerErr: httperr.Newf(http.StatusInternalServerError, "some get error"), + wantStatus: http.StatusInternalServerError, + wantContentType: htmlContentType, + wantBody: "Internal Server Error: some get error\n", + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "valid POST request when POST endpoint handler returns an error", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + postHandlerErr: httperr.Newf(http.StatusInternalServerError, "some post error"), + wantStatus: http.StatusInternalServerError, + wantContentType: htmlContentType, + wantBody: "Internal Server Error: some post error\n", + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy GET request for LDAP upstream", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyGetResult, + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy POST request for LDAP upstream", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyPostResult, + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy GET request for ActiveDirectory upstream", + method: http.MethodGet, + path: newRequestPath().WithState(happyActiveDirectoryState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyGetResult, + wantEncodedState: happyActiveDirectoryState, + wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + }, + { + name: "happy POST request for ActiveDirectory upstream", + method: http.MethodPost, + path: newRequestPath().WithState(happyActiveDirectoryState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyPostResult, + wantEncodedState: happyActiveDirectoryState, + wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + }, + } + + for _, test := range tests { + tt := test + + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } + rsp := httptest.NewRecorder() + + testGetHandler := func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, + ) error { + require.Equal(t, req, r) + require.Equal(t, rsp, w) + require.Equal(t, tt.wantEncodedState, encodedState) + require.Equal(t, tt.wantDecodedState, decodedState) + if tt.getHandlerErr == nil { + _, err := w.Write([]byte(happyGetResult)) + require.NoError(t, err) + } + return tt.getHandlerErr + } + + testPostHandler := func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, + ) error { + require.Equal(t, req, r) + require.Equal(t, rsp, w) + require.Equal(t, tt.wantEncodedState, encodedState) + require.Equal(t, tt.wantDecodedState, decodedState) + if tt.postHandlerErr == nil { + _, err := w.Write([]byte(happyPostResult)) + require.NoError(t, err) + } + return tt.postHandlerErr + } + + subject := NewHandler(happyStateCodec, happyCookieCodec, testGetHandler, testPostHandler) + + subject.ServeHTTP(rsp, req) + + testutil.RequireSecurityHeaders(t, rsp) + + require.Equal(t, tt.wantStatus, rsp.Code) + testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) + require.Equal(t, tt.wantBody, rsp.Body.String()) + }) + } +} + +type requestPath struct { + state *string +} + +func newRequestPath() *requestPath { + return &requestPath{} +} + +func (r *requestPath) WithState(state string) *requestPath { + r.state = &state + return r +} + +func (r *requestPath) WithoutState() *requestPath { + r.state = nil + return r +} + +func (r *requestPath) String() string { + path := "/login?" + params := url.Values{} + if r.state != nil { + params.Add("state", *r.state) + } + return path + params.Encode() +} diff --git a/internal/oidc/login/post_login_handler.go b/internal/oidc/login/post_login_handler.go new file mode 100644 index 00000000..33819c69 --- /dev/null +++ b/internal/oidc/login/post_login_handler.go @@ -0,0 +1,19 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + + "github.com/ory/fosite" + + "go.pinniped.dev/internal/oidc" +) + +func NewPostHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { + // TODO + return nil + } +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 9467eb22..90c47655 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -5,12 +5,15 @@ package oidc import ( + "crypto/subtle" + "net/http" "time" coreosoidc "github.com/coreos/go-oidc/v3/oidc" "github.com/ory/fosite" "github.com/ory/fosite/compose" + "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/provider" @@ -297,3 +300,68 @@ func ScopeWasRequested(authorizeRequester fosite.AuthorizeRequester, scopeName s } return false } + +func ReadStateParamAndValidateCSRFCookie(r *http.Request, cookieDecoder Decoder, stateDecoder Decoder) (string, *UpstreamStateParamData, error) { + csrfValue, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + return "", nil, err + } + + encodedState, decodedState, err := readStateParam(r, stateDecoder) + if err != nil { + return "", nil, err + } + + err = validateCSRFValue(decodedState, csrfValue) + if err != nil { + return "", nil, err + } + + return encodedState, decodedState, nil +} + +func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) { + receivedCSRFCookie, err := r.Cookie(CSRFCookieName) + if err != nil { + // Error means that the cookie was not found + return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) + } + + var csrfFromCookie csrftoken.CSRFToken + err = cookieDecoder.Decode(CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + if err != nil { + return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) + } + + return csrfFromCookie, nil +} + +func readStateParam(r *http.Request, stateDecoder Decoder) (string, *UpstreamStateParamData, error) { + encodedState := r.FormValue("state") + + if encodedState == "" { + return "", nil, httperr.New(http.StatusBadRequest, "state param not found") + } + + var state UpstreamStateParamData + if err := stateDecoder.Decode( + UpstreamStateParamEncodingName, + r.FormValue("state"), + &state, + ); err != nil { + return "", nil, httperr.New(http.StatusBadRequest, "error reading state") + } + + if state.FormatVersion != UpstreamStateParamFormatVersion { + return "", nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") + } + + return encodedState, &state, nil +} + +func validateCSRFValue(state *UpstreamStateParamData, csrfCookieValue csrftoken.CSRFToken) error { + if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfCookieValue)) != 1 { + return httperr.New(http.StatusForbidden, "CSRF value does not match") + } + return nil +} diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index dda7fa86..283b1808 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -136,7 +136,12 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs oauthHelperWithKubeStorage, ) - m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler() + m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler( + upstreamStateEncoder, + csrfCookieEncoder, + login.NewGetHandler(m.upstreamIDPs), + login.NewPostHandler(m.upstreamIDPs, oauthHelperWithKubeStorage), + ) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) } diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 1936c406..c408ada9 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -15,6 +15,7 @@ import ( "time" coreosoidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/securecookie" "github.com/ory/fosite" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -837,6 +838,44 @@ type ExpectedUpstreamStateParamFormat struct { V string `json:"v"` } +type UpstreamStateParamBuilder ExpectedUpstreamStateParamFormat + +func (b UpstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { + state, err := stateEncoder.Encode("s", b) + require.NoError(t, err) + return state +} + +func (b *UpstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *UpstreamStateParamBuilder { + b.P = params + return b +} + +func (b *UpstreamStateParamBuilder) WithNonce(nonce string) *UpstreamStateParamBuilder { + b.N = nonce + return b +} + +func (b *UpstreamStateParamBuilder) WithCSRF(csrf string) *UpstreamStateParamBuilder { + b.C = csrf + return b +} + +func (b *UpstreamStateParamBuilder) WithPKCE(pkce string) *UpstreamStateParamBuilder { + b.K = pkce + return b +} + +func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType string) *UpstreamStateParamBuilder { + b.T = upstreamIDPType + return b +} + +func (b *UpstreamStateParamBuilder) WithStateVersion(version string) *UpstreamStateParamBuilder { + b.V = version + return b +} + type staticKeySet struct { publicKey crypto.PublicKey }