callback_handler.go: Get upstream name from state instead of path

Also use ConstantTimeCompare() to compare CSRF tokens to prevent
leaking any information in how quickly we reject bad tokens.

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-20 13:33:08 -08:00 committed by Ryan Richard
parent 72321fc106
commit b21f0035d7
2 changed files with 10 additions and 18 deletions

View File

@ -5,10 +5,10 @@
package callback package callback
import ( import (
"crypto/subtle"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"path"
"time" "time"
"github.com/ory/fosite" "github.com/ory/fosite"
@ -49,7 +49,7 @@ func NewHandler(
return err return err
} }
upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter) upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
if upstreamIDPConfig == nil { if upstreamIDPConfig == nil {
plog.Warning("upstream provider not found") plog.Warning("upstream provider not found")
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return nil, err return nil, err
} }
if state.CSRFToken != csrfValue { if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
plog.InfoErr("CSRF value does not match", err) plog.InfoErr("CSRF value does not match", err)
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
} }
@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return state, nil return state, nil
} }
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI { func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
_, lastPathComponent := path.Split(r.URL.Path)
for _, p := range idpListGetter.GetIDPList() { for _, p := range idpListGetter.GetIDPList() {
if p.GetName() == lastPathComponent { if p.GetName() == upstreamName {
return p return p
} }
} }

View File

@ -6,7 +6,6 @@ package callback
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) {
} }
type requestPath struct { type requestPath struct {
upstreamIDPName, code, state *string code, state *string
} }
func newRequestPath() *requestPath { func newRequestPath() *requestPath {
n := happyUpstreamIDPName
c := happyUpstreamAuthcode c := happyUpstreamAuthcode
s := "4321" s := "4321"
return &requestPath{ return &requestPath{
upstreamIDPName: &n, code: &c,
code: &c, state: &s,
state: &s,
} }
} }
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
r.upstreamIDPName = &name
return r
}
func (r *requestPath) WithCode(code string) *requestPath { func (r *requestPath) WithCode(code string) *requestPath {
r.code = &code r.code = &code
return r return r
@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath {
} }
func (r *requestPath) String() string { func (r *requestPath) String() string {
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName) path := "/downstream-provider-name/callback?"
params := url.Values{} params := url.Values{}
if r.code != nil { if r.code != nil {
params.Add("code", *r.code) params.Add("code", *r.code)
@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
func happyUpstreamStateParam() *upstreamStateParamBuilder { func happyUpstreamStateParam() *upstreamStateParamBuilder {
return &upstreamStateParamBuilder{ return &upstreamStateParamBuilder{
U: happyUpstreamIDPName,
P: happyDownstreamRequestParams, P: happyDownstreamRequestParams,
N: happyDownstreamNonce, N: happyDownstreamNonce,
C: happyDownstreamCSRF, C: happyDownstreamCSRF,