From b21f0035d7fd710539788f01b2727014941d7784 Mon Sep 17 00:00:00 2001 From: Andrew Keesler Date: Fri, 20 Nov 2020 13:33:08 -0800 Subject: [PATCH] callback_handler.go: Get upstream name from state instead of path Also use ConstantTimeCompare() to compare CSRF tokens to prevent leaking any information in how quickly we reject bad tokens. Signed-off-by: Ryan Richard --- internal/oidc/callback/callback_handler.go | 11 +++++------ internal/oidc/callback/callback_handler_test.go | 17 +++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 5b725623..f237726b 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,10 +5,10 @@ package callback import ( + "crypto/subtle" "fmt" "net/http" "net/url" - "path" "time" "github.com/ory/fosite" @@ -49,7 +49,7 @@ func NewHandler( return err } - upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter) + upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter) if upstreamIDPConfig == nil { plog.Warning("upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") @@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return nil, err } - if state.CSRFToken != csrfValue { + if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 { plog.InfoErr("CSRF value does not match", err) return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) } @@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) return state, nil } -func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { - _, lastPathComponent := path.Split(r.URL.Path) +func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { for _, p := range idpListGetter.GetIDPList() { - if p.GetName() == lastPathComponent { + if p.GetName() == upstreamName { return p } } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 5f1d16ba..c51ad86a 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -6,7 +6,6 @@ package callback import ( "context" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" @@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) { } type requestPath struct { - upstreamIDPName, code, state *string + code, state *string } func newRequestPath() *requestPath { - n := happyUpstreamIDPName c := happyUpstreamAuthcode s := "4321" return &requestPath{ - upstreamIDPName: &n, - code: &c, - state: &s, + code: &c, + state: &s, } } -func (r *requestPath) WithUpstreamIDPName(name string) *requestPath { - r.upstreamIDPName = &name - return r -} - func (r *requestPath) WithCode(code string) *requestPath { r.code = &code return r @@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath { } func (r *requestPath) String() string { - path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName) + path := "/downstream-provider-name/callback?" params := url.Values{} if r.code != nil { params.Add("code", *r.code) @@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat func happyUpstreamStateParam() *upstreamStateParamBuilder { return &upstreamStateParamBuilder{ + U: happyUpstreamIDPName, P: happyDownstreamRequestParams, N: happyDownstreamNonce, C: happyDownstreamCSRF,