diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index fbf13728..bcb8bf1b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,7 +5,6 @@ package callback import ( - "crypto/subtle" "net/http" "net/url" @@ -14,7 +13,6 @@ import ( "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/oidc" - "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/formposthtml" @@ -102,9 +100,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } - csrfValue, err := readCSRFCookie(r, cookieDecoder) + _, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder) if err != nil { - plog.InfoErr("error reading CSRF cookie", err) + plog.InfoErr("state or CSRF error", err) return nil, err } @@ -113,23 +111,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, httperr.New(http.StatusBadRequest, "code param not found") } - if r.FormValue("state") == "" { - plog.Info("state param not found") - return nil, httperr.New(http.StatusBadRequest, "state param not found") - } - - state, err := readState(r, stateDecoder) - if err != nil { - plog.InfoErr("error reading state", err) - return nil, err - } - - if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 { - plog.InfoErr("CSRF value does not match", err) - return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) - } - - return state, nil + return decodedState, nil } func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister) provider.UpstreamOIDCIdentityProviderI { @@ -140,36 +122,3 @@ func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCId } return nil } - -func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) { - receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) - if err != nil { - // Error means that the cookie was not found - return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) - } - - var csrfFromCookie csrftoken.CSRFToken - err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) - if err != nil { - return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) - } - - return csrfFromCookie, nil -} - -func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { - var state oidc.UpstreamStateParamData - if err := stateDecoder.Decode( - oidc.UpstreamStateParamEncodingName, - r.FormValue("state"), - &state, - ); err != nil { - return nil, httperr.New(http.StatusBadRequest, "error reading state") - } - - if state.FormatVersion != oidc.UpstreamStateParamFormatVersion { - return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") - } - - return &state, nil -} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 21380979..6fc47773 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -1156,10 +1156,8 @@ func (r *requestPath) String() string { return path + params.Encode() } -type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat - -func happyUpstreamStateParam() *upstreamStateParamBuilder { - return &upstreamStateParamBuilder{ +func happyUpstreamStateParam() *oidctestutil.UpstreamStateParamBuilder { + return &oidctestutil.UpstreamStateParamBuilder{ U: happyUpstreamIDPName, P: happyDownstreamRequestParams, T: "oidc", @@ -1170,37 +1168,6 @@ func happyUpstreamStateParam() *upstreamStateParamBuilder { } } -func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { - state, err := stateEncoder.Encode("s", b) - require.NoError(t, err) - return state -} - -func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder { - b.P = params - return b -} - -func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder { - b.N = nonce - return b -} - -func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder { - b.C = csrf - return b -} - -func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder { - b.K = pkce - return b -} - -func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder { - b.V = version - return b -} - func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). WithName(happyUpstreamIDPName). diff --git a/internal/oidc/login/get_login_handler.go b/internal/oidc/login/get_login_handler.go new file mode 100644 index 00000000..e1d6ffb6 --- /dev/null +++ b/internal/oidc/login/get_login_handler.go @@ -0,0 +1,17 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + + "go.pinniped.dev/internal/oidc" +) + +func NewGetHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { + // TODO + return nil + } +} diff --git a/internal/oidc/login/login_handler.go b/internal/oidc/login/login_handler.go index 10727b3c..a8e65e0e 100644 --- a/internal/oidc/login/login_handler.go +++ b/internal/oidc/login/login_handler.go @@ -5,19 +5,64 @@ package login import ( "net/http" + + idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/httputil/securityheader" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/plog" ) -// NewHandler returns an http.Handler that serves the login endpoint for IDPs that -// don't have their own Web UI. -func NewHandler() http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, `Method not allowed (try GET)`, http.StatusMethodNotAllowed) - return +// HandlerFunc is a function that can handle either a GET or POST request for the login endpoint. +type HandlerFunc func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, +) error + +// NewHandler returns a http.Handler that serves the login endpoint for IDPs that don't have their own web UI for login. +// +// This handler takes care of the shared concerns between the GET and POST methods of the login endpoint: +// checking the method, checking the CSRF cookie, decoding the state param, and adding security headers. +// Then it defers the rest of the handling to the passed in handler functions for GET and POST requests. +// Note that CSRF protection isn't needed on GET requests, but it doesn't hurt. Putting it here +// keeps the implementations and tests of HandlerFunc simpler since they won't need to deal with any decoders. +// Users should always initially get redirected to this page from the authorization endpoint, and never need +// to navigate directly to this page in their browser without going through the authorization endpoint first. +// Once their browser has landed on this page, it should be okay for the user to refresh the browser. +func NewHandler( + stateDecoder oidc.Decoder, + cookieDecoder oidc.Decoder, + getHandler HandlerFunc, // use NewGetHandler() for production + postHandler HandlerFunc, // use NewPostHandler() for production +) http.Handler { + loginHandler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + var handler HandlerFunc + switch r.Method { + case http.MethodGet: + handler = getHandler + case http.MethodPost: + handler = postHandler + default: + return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) } - _, err := w.Write([]byte("

hello world

")) + + encodedState, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + plog.InfoErr("state or CSRF error", err) + return err } + + switch decodedState.UpstreamType { + case string(idpdiscoveryv1alpha1.IDPTypeLDAP), string(idpdiscoveryv1alpha1.IDPTypeActiveDirectory): + // these are the types supported by this endpoint, so no error here + default: + return httperr.Newf(http.StatusBadRequest, "not a supported upstream IDP type for this endpoint: %q", decodedState.UpstreamType) + } + + return handler(w, r, encodedState, decodedState) }) + + return securityheader.Wrap(loginHandler) } diff --git a/internal/oidc/login/login_handler_test.go b/internal/oidc/login/login_handler_test.go new file mode 100644 index 00000000..c77758da --- /dev/null +++ b/internal/oidc/login/login_handler_test.go @@ -0,0 +1,448 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/securecookie" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/testutil" + "go.pinniped.dev/internal/testutil/oidctestutil" +) + +func TestLoginEndpoint(t *testing.T) { + const ( + htmlContentType = "text/html; charset=utf-8" + happyGetResult = "

get handler result

" + happyPostResult = "

post handler result

" + + happyUpstreamIDPName = "upstream-idp-name" + happyUpstreamIDPType = "ldap" + happyDownstreamCSRF = "test-csrf" + happyDownstreamPKCE = "test-pkce" + happyDownstreamNonce = "test-nonce" + happyDownstreamStateVersion = "2" + + downstreamClientID = "pinniped-cli" + happyDownstreamState = "8b-state" + downstreamNonce = "some-nonce-value" + downstreamPKCEChallenge = "some-challenge" + downstreamPKCEChallengeMethod = "S256" + downstreamRedirectURI = "http://127.0.0.1/callback" + ) + + happyDownstreamScopesRequested := []string{"openid"} + happyDownstreamRequestParamsQuery := url.Values{ + "response_type": []string{"code"}, + "scope": []string{strings.Join(happyDownstreamScopesRequested, " ")}, + "client_id": []string{downstreamClientID}, + "state": []string{happyDownstreamState}, + "nonce": []string{downstreamNonce}, + "code_challenge": []string{downstreamPKCEChallenge}, + "code_challenge_method": []string{downstreamPKCEChallengeMethod}, + "redirect_uri": []string{downstreamRedirectURI}, + } + happyDownstreamRequestParams := happyDownstreamRequestParamsQuery.Encode() + + expectedHappyDecodedUpstreamStateParam := func() *oidc.UpstreamStateParamData { + return &oidc.UpstreamStateParamData{ + UpstreamName: happyUpstreamIDPName, + UpstreamType: happyUpstreamIDPType, + AuthParams: happyDownstreamRequestParams, + Nonce: happyDownstreamNonce, + CSRFToken: happyDownstreamCSRF, + PKCECode: happyDownstreamPKCE, + FormatVersion: happyDownstreamStateVersion, + } + } + + expectedHappyDecodedUpstreamStateParamForActiveDirectory := func() *oidc.UpstreamStateParamData { + s := expectedHappyDecodedUpstreamStateParam() + s.UpstreamType = "activedirectory" + return s + } + + happyUpstreamStateParam := func() *oidctestutil.UpstreamStateParamBuilder { + return &oidctestutil.UpstreamStateParamBuilder{ + U: happyUpstreamIDPName, + T: happyUpstreamIDPType, + P: happyDownstreamRequestParams, + N: happyDownstreamNonce, + C: happyDownstreamCSRF, + K: happyDownstreamPKCE, + V: happyDownstreamStateVersion, + } + } + + stateEncoderHashKey := []byte("fake-hash-secret") + stateEncoderBlockKey := []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + cookieEncoderHashKey := []byte("fake-hash-secret2") + cookieEncoderBlockKey := []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + happyStateCodec := securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + happyCookieCodec := securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + + happyState := happyUpstreamStateParam().Build(t, happyStateCodec) + happyPathWithState := newRequestPath().WithState(happyState).String() + + happyActiveDirectoryState := happyUpstreamStateParam().WithUpstreamIDPType("activedirectory").Build(t, happyStateCodec) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF) + require.NoError(t, err) + happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + + tests := []struct { + name string + method string + path string + csrfCookie string + getHandlerErr error + postHandlerErr error + + wantStatus int + wantContentType string + wantBody string + wantEncodedState string + wantDecodedState *oidc.UpstreamStateParamData + }{ + { + name: "PUT method is invalid", + method: http.MethodPut, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: PUT (try GET or POST)\n", + }, + { + name: "PATCH method is invalid", + method: http.MethodPatch, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: PATCH (try GET or POST)\n", + }, + { + name: "DELETE method is invalid", + method: http.MethodDelete, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: DELETE (try GET or POST)\n", + }, + { + name: "HEAD method is invalid", + method: http.MethodHead, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: HEAD (try GET or POST)\n", + }, + { + name: "CONNECT method is invalid", + method: http.MethodConnect, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: CONNECT (try GET or POST)\n", + }, + { + name: "OPTIONS method is invalid", + method: http.MethodOptions, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: OPTIONS (try GET or POST)\n", + }, + { + name: "TRACE method is invalid", + method: http.MethodTrace, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusMethodNotAllowed, + wantContentType: htmlContentType, + wantBody: "Method Not Allowed: TRACE (try GET or POST)\n", + }, + { + name: "state param was not included on GET request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not included on POST request", + method: http.MethodPost, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request", + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: error reading state\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request", + method: http.MethodPost, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: error reading state\n", + }, + { + name: "the CSRF cookie does not exist on GET request", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: "", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF cookie is missing\n", + }, + { + name: "the CSRF cookie does not exist on POST request", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: "", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF cookie is missing\n", + }, + { + name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "cookie csrf value does not match state csrf value on GET request", + method: http.MethodGet, + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF value does not match\n", + }, + { + name: "cookie csrf value does not match state csrf value on POST request", + method: http.MethodPost, + path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantContentType: htmlContentType, + wantBody: "Forbidden: CSRF value does not match\n", + }, + { + name: "GET request when upstream IDP type in state param is not supported by this endpoint", + method: http.MethodGet, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n", + }, + { + name: "POST request when upstream IDP type in state param is not supported by this endpoint", + method: http.MethodPost, + path: newRequestPath().WithState( + happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec), + ).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantContentType: htmlContentType, + wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n", + }, + { + name: "valid GET request when GET endpoint handler returns an error", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + getHandlerErr: httperr.Newf(http.StatusInternalServerError, "some get error"), + wantStatus: http.StatusInternalServerError, + wantContentType: htmlContentType, + wantBody: "Internal Server Error: some get error\n", + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "valid POST request when POST endpoint handler returns an error", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + postHandlerErr: httperr.Newf(http.StatusInternalServerError, "some post error"), + wantStatus: http.StatusInternalServerError, + wantContentType: htmlContentType, + wantBody: "Internal Server Error: some post error\n", + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy GET request for LDAP upstream", + method: http.MethodGet, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyGetResult, + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy POST request for LDAP upstream", + method: http.MethodPost, + path: happyPathWithState, + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyPostResult, + wantEncodedState: happyState, + wantDecodedState: expectedHappyDecodedUpstreamStateParam(), + }, + { + name: "happy GET request for ActiveDirectory upstream", + method: http.MethodGet, + path: newRequestPath().WithState(happyActiveDirectoryState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyGetResult, + wantEncodedState: happyActiveDirectoryState, + wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + }, + { + name: "happy POST request for ActiveDirectory upstream", + method: http.MethodPost, + path: newRequestPath().WithState(happyActiveDirectoryState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusOK, + wantContentType: htmlContentType, + wantBody: happyPostResult, + wantEncodedState: happyActiveDirectoryState, + wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(), + }, + } + + for _, test := range tests { + tt := test + + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } + rsp := httptest.NewRecorder() + + testGetHandler := func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, + ) error { + require.Equal(t, req, r) + require.Equal(t, rsp, w) + require.Equal(t, tt.wantEncodedState, encodedState) + require.Equal(t, tt.wantDecodedState, decodedState) + if tt.getHandlerErr == nil { + _, err := w.Write([]byte(happyGetResult)) + require.NoError(t, err) + } + return tt.getHandlerErr + } + + testPostHandler := func( + w http.ResponseWriter, + r *http.Request, + encodedState string, + decodedState *oidc.UpstreamStateParamData, + ) error { + require.Equal(t, req, r) + require.Equal(t, rsp, w) + require.Equal(t, tt.wantEncodedState, encodedState) + require.Equal(t, tt.wantDecodedState, decodedState) + if tt.postHandlerErr == nil { + _, err := w.Write([]byte(happyPostResult)) + require.NoError(t, err) + } + return tt.postHandlerErr + } + + subject := NewHandler(happyStateCodec, happyCookieCodec, testGetHandler, testPostHandler) + + subject.ServeHTTP(rsp, req) + + testutil.RequireSecurityHeaders(t, rsp) + + require.Equal(t, tt.wantStatus, rsp.Code) + testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType) + require.Equal(t, tt.wantBody, rsp.Body.String()) + }) + } +} + +type requestPath struct { + state *string +} + +func newRequestPath() *requestPath { + return &requestPath{} +} + +func (r *requestPath) WithState(state string) *requestPath { + r.state = &state + return r +} + +func (r *requestPath) WithoutState() *requestPath { + r.state = nil + return r +} + +func (r *requestPath) String() string { + path := "/login?" + params := url.Values{} + if r.state != nil { + params.Add("state", *r.state) + } + return path + params.Encode() +} diff --git a/internal/oidc/login/post_login_handler.go b/internal/oidc/login/post_login_handler.go new file mode 100644 index 00000000..33819c69 --- /dev/null +++ b/internal/oidc/login/post_login_handler.go @@ -0,0 +1,19 @@ +// Copyright 2022 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package login + +import ( + "net/http" + + "github.com/ory/fosite" + + "go.pinniped.dev/internal/oidc" +) + +func NewPostHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { + // TODO + return nil + } +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 9467eb22..90c47655 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -5,12 +5,15 @@ package oidc import ( + "crypto/subtle" + "net/http" "time" coreosoidc "github.com/coreos/go-oidc/v3/oidc" "github.com/ory/fosite" "github.com/ory/fosite/compose" + "go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/provider" @@ -297,3 +300,68 @@ func ScopeWasRequested(authorizeRequester fosite.AuthorizeRequester, scopeName s } return false } + +func ReadStateParamAndValidateCSRFCookie(r *http.Request, cookieDecoder Decoder, stateDecoder Decoder) (string, *UpstreamStateParamData, error) { + csrfValue, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + return "", nil, err + } + + encodedState, decodedState, err := readStateParam(r, stateDecoder) + if err != nil { + return "", nil, err + } + + err = validateCSRFValue(decodedState, csrfValue) + if err != nil { + return "", nil, err + } + + return encodedState, decodedState, nil +} + +func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) { + receivedCSRFCookie, err := r.Cookie(CSRFCookieName) + if err != nil { + // Error means that the cookie was not found + return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) + } + + var csrfFromCookie csrftoken.CSRFToken + err = cookieDecoder.Decode(CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + if err != nil { + return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) + } + + return csrfFromCookie, nil +} + +func readStateParam(r *http.Request, stateDecoder Decoder) (string, *UpstreamStateParamData, error) { + encodedState := r.FormValue("state") + + if encodedState == "" { + return "", nil, httperr.New(http.StatusBadRequest, "state param not found") + } + + var state UpstreamStateParamData + if err := stateDecoder.Decode( + UpstreamStateParamEncodingName, + r.FormValue("state"), + &state, + ); err != nil { + return "", nil, httperr.New(http.StatusBadRequest, "error reading state") + } + + if state.FormatVersion != UpstreamStateParamFormatVersion { + return "", nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") + } + + return encodedState, &state, nil +} + +func validateCSRFValue(state *UpstreamStateParamData, csrfCookieValue csrftoken.CSRFToken) error { + if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfCookieValue)) != 1 { + return httperr.New(http.StatusForbidden, "CSRF value does not match") + } + return nil +} diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index dda7fa86..283b1808 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -136,7 +136,12 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs oauthHelperWithKubeStorage, ) - m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler() + m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler( + upstreamStateEncoder, + csrfCookieEncoder, + login.NewGetHandler(m.upstreamIDPs), + login.NewPostHandler(m.upstreamIDPs, oauthHelperWithKubeStorage), + ) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) } diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index 1936c406..c408ada9 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -15,6 +15,7 @@ import ( "time" coreosoidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/gorilla/securecookie" "github.com/ory/fosite" "github.com/stretchr/testify/require" "golang.org/x/oauth2" @@ -837,6 +838,44 @@ type ExpectedUpstreamStateParamFormat struct { V string `json:"v"` } +type UpstreamStateParamBuilder ExpectedUpstreamStateParamFormat + +func (b UpstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { + state, err := stateEncoder.Encode("s", b) + require.NoError(t, err) + return state +} + +func (b *UpstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *UpstreamStateParamBuilder { + b.P = params + return b +} + +func (b *UpstreamStateParamBuilder) WithNonce(nonce string) *UpstreamStateParamBuilder { + b.N = nonce + return b +} + +func (b *UpstreamStateParamBuilder) WithCSRF(csrf string) *UpstreamStateParamBuilder { + b.C = csrf + return b +} + +func (b *UpstreamStateParamBuilder) WithPKCE(pkce string) *UpstreamStateParamBuilder { + b.K = pkce + return b +} + +func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType string) *UpstreamStateParamBuilder { + b.T = upstreamIDPType + return b +} + +func (b *UpstreamStateParamBuilder) WithStateVersion(version string) *UpstreamStateParamBuilder { + b.V = version + return b +} + type staticKeySet struct { publicKey crypto.PublicKey }