diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 96cb4464..18c5e98e 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -9,8 +9,6 @@ import ( "net/http" "time" - "github.com/gorilla/securecookie" - "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" @@ -25,22 +23,6 @@ import ( "go.pinniped.dev/internal/plog" ) -const ( - // Just in case we need to make a breaking change to the format of the upstream state param, - // we are including a format version number. This gives the opportunity for a future version of Pinniped - // to have the consumer of this format decide to reject versions that it doesn't understand. - upstreamStateParamFormatVersion = "1" - - // The `name` passed to the encoder for encoding the upstream state param value. This name is short - // because it will be encoded into the upstream state param value and we're trying to keep that small. - upstreamStateParamEncodingName = "s" -) - -// Encoder is the encoding side of the securecookie.Codec interface. -type Encoder interface { - Encode(name string, value interface{}) (string, error) -} - func NewHandler( issuer string, idpListGetter oidc.IDPListGetter, @@ -48,8 +30,8 @@ func NewHandler( generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), generateNonce func() (nonce.Nonce, error), - upstreamStateEncoder Encoder, - cookieCodec securecookie.Codec, + upstreamStateEncoder oidc.Encoder, + cookieCodec oidc.Codec, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodPost && r.Method != http.MethodGet { @@ -144,7 +126,7 @@ func NewHandler( }) } -func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { +func readCSRFCookie(r *http.Request, codec oidc.Codec) (csrftoken.CSRFToken, error) { receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found @@ -204,37 +186,28 @@ func generateValues( return csrfValue, nonceValue, pkceValue, nil } -// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param. -type upstreamStateParamData struct { - AuthParams string `json:"p"` - Nonce nonce.Nonce `json:"n"` - CSRFToken csrftoken.CSRFToken `json:"c"` - PKCECode pkce.Code `json:"k"` - StateParamFormatVersion string `json:"v"` -} - func upstreamStateParam( authorizeRequester fosite.AuthorizeRequester, nonceValue nonce.Nonce, csrfValue csrftoken.CSRFToken, pkceValue pkce.Code, - encoder Encoder, + encoder oidc.Encoder, ) (string, error) { - stateParamData := upstreamStateParamData{ - AuthParams: authorizeRequester.GetRequestForm().Encode(), - Nonce: nonceValue, - CSRFToken: csrfValue, - PKCECode: pkceValue, - StateParamFormatVersion: upstreamStateParamFormatVersion, + stateParamData := oidc.UpstreamStateParamData{ + AuthParams: authorizeRequester.GetRequestForm().Encode(), + Nonce: nonceValue, + CSRFToken: csrfValue, + PKCECode: pkceValue, + FormatVersion: oidc.UpstreamStateParamFormatVersion, } - encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData) + encodedStateParamValue, err := encoder.Encode(oidc.UpstreamStateParamEncodingName, stateParamData) if err != nil { return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err) } return encodedStateParamValue, nil } -func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { +func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec oidc.Codec) error { encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue) if err != nil { return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 173021d8..6f38a8d9 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -249,8 +249,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) - stateEncoder securecookie.Codec - cookieEncoder securecookie.Codec + stateEncoder oidc.Codec + cookieEncoder oidc.Codec method string path string contentType string @@ -807,7 +807,7 @@ func TestAuthorizationEndpoint(t *testing.T) { } type errorReturningEncoder struct { - securecookie.Codec + oidc.Codec } func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { @@ -830,7 +830,7 @@ func requireEqualContentType(t *testing.T, actual string, expected string) { require.Equal(t, actualContentTypeParams, expectedContentTypeParams) } -func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) { +func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) { t.Helper() actualLocationURL, err := url.Parse(actualURL) require.NoError(t, err) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 9c5ad6fd..5bd04127 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -12,43 +12,62 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/plog" ) -// Decoder is the decoding side of the securecookie.Codec interface. -type Decoder interface { - Decode(name, value string, into interface{}) error -} - func NewHandler( idpListGetter oidc.IDPListGetter, - cookieDecoder Decoder, + stateDecoder, cookieDecoder oidc.Decoder, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) - } - - _, err := readCSRFCookie(r, cookieDecoder) - if err != nil { + if err := validateRequest(r, stateDecoder, cookieDecoder); err != nil { return err } - if r.FormValue("code") == "" { - return httperr.New(http.StatusBadRequest, "code param not found") - } - - if r.FormValue("state") == "" { - return httperr.New(http.StatusBadRequest, "state param not found") - } - if findUpstreamIDPConfig(r, idpListGetter) == nil { + plog.Warning("upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") } - return httperr.New(http.StatusBadRequest, "state param not valid") + return nil }) } +func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) error { + if r.Method != http.MethodGet { + return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + } + + csrfValue, err := readCSRFCookie(r, cookieDecoder) + if err != nil { + plog.InfoErr("error reading CSRF cookie", err) + return err + } + + if r.FormValue("code") == "" { + plog.Info("code param not found") + return httperr.New(http.StatusBadRequest, "code param not found") + } + + if r.FormValue("state") == "" { + plog.Info("state param not found") + return httperr.New(http.StatusBadRequest, "state param not found") + } + + state, err := readState(r, stateDecoder) + if err != nil { + plog.InfoErr("error reading state", err) + return err + } + + if state.CSRFToken != csrfValue { + plog.InfoErr("CSRF value does not match", err) + return httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) + } + + return nil +} + func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { _, lastPathComponent := path.Split(r.URL.Path) for _, p := range idpListGetter.GetIDPList() { @@ -59,18 +78,35 @@ func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *p return nil } -func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) { +func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) { receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName) if err != nil { // Error means that the cookie was not found - return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err) } var csrfFromCookie csrftoken.CSRFToken err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) if err != nil { - return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err) + return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err) } return csrfFromCookie, nil } + +func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { + var state oidc.UpstreamStateParamData + if err := stateDecoder.Decode( + oidc.UpstreamStateParamEncodingName, + r.FormValue("state"), + &state, + ); err != nil { + return nil, httperr.New(http.StatusBadRequest, "error reading state") + } + + if state.FormatVersion != oidc.UpstreamStateParamFormatVersion { + return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid") + } + + return &state, nil +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 8d966db4..46dff9e2 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -48,28 +48,49 @@ func TestCallbackEndpoint(t *testing.T) { require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) - var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) - happyStateEncoder.SetSerializer(securecookie.JSONEncoder{}) - var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) - happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) + var happyStateCodec = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateCodec.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) - // happyCSRF := "test-csrf" - // happyPKCE := "test-pkce" - // happyNonce := "test-nonce" - // - // happyEncodedState, err := happyStateEncoder.Encode("s", - // testutil.ExpectedUpstreamStateParamFormat{ - // P: "todo query goes here", - // N: happyNonce, - // C: happyCSRF, - // K: happyPKCE, - // V: "1", - // }, - // ) - // require.NoError(t, err) + happyCSRF := "test-csrf" + happyPKCE := "test-pkce" + happyNonce := "test-nonce" - incomingCookieCSRFValue := "csrf-value-from-cookie" - encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue) + happyState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: "1", + }, + ) + require.NoError(t, err) + + wrongCSRFValueState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: "wrong-csrf-value", + K: happyPKCE, + V: "1", + }, + ) + require.NoError(t, err) + + wrongVersionState, err := happyStateCodec.Encode("s", + testutil.ExpectedUpstreamStateParamFormat{ + P: "todo query goes here", + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: "wrong-version", + }, + ) + require.NoError(t, err) + + encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF) require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue @@ -77,7 +98,6 @@ func TestCallbackEndpoint(t *testing.T) { name string idpListGetter provider.DynamicUpstreamIDPProvider - cookieDecoder Decoder method string path string csrfCookie string @@ -118,39 +138,44 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Method Not Allowed: DELETE (try GET)\n", }, { - name: "code param was not included on request", - cookieDecoder: happyCookieEncoder, - method: http.MethodGet, - path: newRequestPath().WithoutCode().String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: code param not found\n", + name: "code param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithoutCode().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", }, { - name: "state param was not included on request", - cookieDecoder: happyCookieEncoder, - method: http.MethodGet, - path: newRequestPath().WithoutState().String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: state param not found\n", + name: "state param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: state param not valid\n", + wantBody: "Bad Request: error reading state\n", + }, + { + name: "state's internal version does not match what we want", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(wrongVersionState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: state format version is invalid\n", }, { name: "the UpstreamOIDCProvider CRD has been deleted", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusUnprocessableEntity, wantBody: "Unprocessable Entity: upstream provider not found\n", @@ -158,24 +183,29 @@ func TestCallbackEndpoint(t *testing.T) { { name: "the CSRF cookie does not exist on request", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, - wantBody: "Forbidden: unauthorized request\n", + wantBody: "Forbidden: CSRF cookie is missing\n", }, { - name: "the CSRF cookie cannot be decoded", + name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - cookieDecoder: happyCookieEncoder, method: http.MethodGet, - path: newRequestPath().String(), + path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", wantStatus: http.StatusForbidden, - wantBody: "Forbidden: unauthorized request\n", + wantBody: "Forbidden: error reading CSRF cookie\n", + }, + { + name: "cookie csrf value does not match state csrf value", + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(wrongCSRFValueState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF value does not match\n", }, - // TODO: csrf value from inside state param does not match csrf cookie value - // TODO: state's internal version does not match what we want // Upstream exchange // TODO: network call to upstream token endpoint fails @@ -201,7 +231,7 @@ func TestCallbackEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.idpListGetter, test.cookieDecoder) + subject := NewHandler(test.idpListGetter, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 3b5dcb60..afa06cc4 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -8,7 +8,10 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/oidcclient/nonce" + "go.pinniped.dev/internal/oidcclient/pkce" ) const ( @@ -19,6 +22,15 @@ const ( ) const ( + // Just in case we need to make a breaking change to the format of the upstream state param, + // we are including a format version number. This gives the opportunity for a future version of Pinniped + // to have the consumer of this format decide to reject versions that it doesn't understand. + UpstreamStateParamFormatVersion = "1" + + // The `name` passed to the encoder for encoding the upstream state param value. This name is short + // because it will be encoded into the upstream state param value and we're trying to keep that small. + UpstreamStateParamEncodingName = "s" + // CSRFCookieName is the name of the browser cookie which shall hold our CSRF value. // The `__Host` prefix has a special meaning. See // https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes. @@ -29,6 +41,36 @@ const ( CSRFCookieEncodingName = "csrf" ) +// Encoder is the encoding side of the securecookie.Codec interface. +type Encoder interface { + Encode(name string, value interface{}) (string, error) +} + +// Decoder is the decoding side of the securecookie.Codec interface. +type Decoder interface { + Decode(name, value string, into interface{}) error +} + +// Codec is both the encoding and decoding sides of the securecookie.Codec interface. It is +// interface'd here so that we properly wrap the securecookie dependency. +type Codec interface { + Encoder + Decoder +} + +// UpstreamStateParamData is the format of the state parameter that we use when we communicate to an +// upstream OIDC provider. +// +// Keep the JSON to a minimal size because the upstream provider could impose size limitations on +// the state param. +type UpstreamStateParamData struct { + AuthParams string `json:"p"` + Nonce nonce.Nonce `json:"n"` + CSRFToken csrftoken.CSRFToken `json:"c"` + PKCECode pkce.Code `json:"k"` + FormatVersion string `json:"v"` +} + func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { return &fosite.DefaultOpenIDConnectClient{ DefaultClient: &fosite.DefaultClient{