callback_handler.go: write 2 invalid cookie tests

Also common-ize some more constants shared between the auth and callback
endpoints.

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-16 11:47:49 -05:00
parent 3ef1171667
commit 4138c9244f
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
4 changed files with 108 additions and 41 deletions

View File

@ -34,16 +34,9 @@ const (
// The `name` passed to the encoder for encoding the upstream state param value. This name is short // The `name` passed to the encoder for encoding the upstream state param value. This name is short
// because it will be encoded into the upstream state param value and we're trying to keep that small. // because it will be encoded into the upstream state param value and we're trying to keep that small.
upstreamStateParamEncodingName = "s" upstreamStateParamEncodingName = "s"
// The name of the browser cookie which shall hold our CSRF value.
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
csrfCookieName = "__Host-pinniped-csrf"
// The `name` passed to the encoder for encoding and decoding the CSRF cookie contents.
csrfCookieEncodingName = "csrf"
) )
// This is the encoding side of the securecookie.Codec interface. // Encoder is the encoding side of the securecookie.Codec interface.
type Encoder interface { type Encoder interface {
Encode(name string, value interface{}) (string, error) Encode(name string, value interface{}) (string, error)
} }
@ -152,14 +145,14 @@ func NewHandler(
} }
func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(csrfCookieName) receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
if err != nil { if err != nil {
// Error means that the cookie was not found // Error means that the cookie was not found
return "", nil return "", nil
} }
var csrfFromCookie csrftoken.CSRFToken var csrfFromCookie csrftoken.CSRFToken
err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) err = codec.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil { if err != nil {
return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err)
} }
@ -242,13 +235,13 @@ func upstreamStateParam(
} }
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error {
encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue) encodedCSRFValue, err := codec.Encode(oidc.CSRFCookieEncodingName, csrfValue)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err)
} }
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: csrfCookieName, Name: oidc.CSRFCookieName,
Value: encodedCSRFValue, Value: encodedCSRFValue,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,

View File

@ -10,17 +10,29 @@ import (
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
) )
// Decoder is the decoding side of the securecookie.Codec interface.
type Decoder interface {
Decode(name, value string, into interface{}) error
}
func NewHandler( func NewHandler(
idpListGetter oidc.IDPListGetter, idpListGetter oidc.IDPListGetter,
cookieDecoder Decoder,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
} }
_, err := readCSRFCookie(r, cookieDecoder)
if err != nil {
return err
}
if r.FormValue("code") == "" { if r.FormValue("code") == "" {
return httperr.New(http.StatusBadRequest, "code param not found") return httperr.New(http.StatusBadRequest, "code param not found")
} }
@ -46,3 +58,19 @@ func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *p
} }
return nil return nil
} }
func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
if err != nil {
// Error means that the cookie was not found
return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err)
}
var csrfFromCookie csrftoken.CSRFToken
err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil {
return "", httperr.Wrap(http.StatusForbidden, "unauthorized request", err)
}
return csrfFromCookie, nil
}

View File

@ -68,12 +68,19 @@ func TestCallbackEndpoint(t *testing.T) {
// ) // )
// require.NoError(t, err) // require.NoError(t, err)
incomingCookieCSRFValue := "csrf-value-from-cookie"
encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue)
require.NoError(t, err)
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
tests := []struct { tests := []struct {
name string name string
idpListGetter provider.DynamicUpstreamIDPProvider
cookieDecoder Decoder
method string method string
path string path string
idpListGetter provider.DynamicUpstreamIDPProvider csrfCookie string
wantStatus int wantStatus int
wantBody string wantBody string
@ -112,36 +119,61 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "code param was not included on request", name: "code param was not included on request",
cookieDecoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithoutCode().String(), path: newRequestPath().WithoutCode().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: code param not found\n", wantBody: "Bad Request: code param not found\n",
}, },
{ {
name: "state param was not included on request", name: "state param was not included on request",
cookieDecoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithoutState().String(), path: newRequestPath().WithoutState().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: state param not found\n", wantBody: "Bad Request: state param not found\n",
}, },
{ {
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
cookieDecoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(), path: newRequestPath().WithState("this-will-not-decode").String(),
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: state param not valid\n", wantBody: "Bad Request: state param not valid\n",
}, },
{ {
name: "the UpstreamOIDCProvider CRD has been deleted", name: "the UpstreamOIDCProvider CRD has been deleted",
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
cookieDecoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().String(), path: newRequestPath().String(),
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: upstream provider not found\n", wantBody: "Unprocessable Entity: upstream provider not found\n",
}, },
// TODO: csrf cookie does not exist on request {
// TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem) name: "the CSRF cookie does not exist on request",
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
cookieDecoder: happyCookieEncoder,
method: http.MethodGet,
path: newRequestPath().String(),
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: unauthorized request\n",
},
{
name: "the CSRF cookie cannot be decoded",
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
cookieDecoder: happyCookieEncoder,
method: http.MethodGet,
path: newRequestPath().String(),
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: unauthorized request\n",
},
// TODO: csrf value from inside state param does not match csrf cookie value // TODO: csrf value from inside state param does not match csrf cookie value
// TODO: state's internal version does not match what we want // TODO: state's internal version does not match what we want
@ -169,8 +201,11 @@ func TestCallbackEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler(test.idpListGetter) subject := NewHandler(test.idpListGetter, test.cookieDecoder)
req := httptest.NewRequest(test.method, test.path, nil) req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie)
}
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)

View File

@ -18,6 +18,17 @@ const (
JWKSEndpointPath = "/jwks.json" JWKSEndpointPath = "/jwks.json"
) )
const (
// CSRFCookieName is the name of the browser cookie which shall hold our CSRF value.
// The `__Host` prefix has a special meaning. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes.
CSRFCookieName = "__Host-pinniped-csrf"
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
// cookie contents.
CSRFCookieEncodingName = "csrf"
)
func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient { func PinnipedCLIOIDCClient() *fosite.DefaultOpenIDConnectClient {
return &fosite.DefaultOpenIDConnectClient{ return &fosite.DefaultOpenIDConnectClient{
DefaultClient: &fosite.DefaultClient{ DefaultClient: &fosite.DefaultClient{