callback_handler.go: Get upstream name from state instead of path

Also use ConstantTimeCompare() to compare CSRF tokens to prevent
leaking any information in how quickly we reject bad tokens.

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-20 13:33:08 -08:00 committed by Ryan Richard
parent 72321fc106
commit b21f0035d7
2 changed files with 10 additions and 18 deletions

View File

@ -5,10 +5,10 @@
package callback
import (
"crypto/subtle"
"fmt"
"net/http"
"net/url"
"path"
"time"
"github.com/ory/fosite"
@ -49,7 +49,7 @@ func NewHandler(
return err
}
upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter)
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
if upstreamIDPConfig == nil {
plog.Warning("upstream provider not found")
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return nil, err
}
if state.CSRFToken != csrfValue {
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
plog.InfoErr("CSRF value does not match", err)
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
}
@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return state, nil
}
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
_, lastPathComponent := path.Split(r.URL.Path)
func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
for _, p := range idpListGetter.GetIDPList() {
if p.GetName() == lastPathComponent {
if p.GetName() == upstreamName {
return p
}
}

View File

@ -6,7 +6,6 @@ package callback
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) {
}
type requestPath struct {
upstreamIDPName, code, state *string
code, state *string
}
func newRequestPath() *requestPath {
n := happyUpstreamIDPName
c := happyUpstreamAuthcode
s := "4321"
return &requestPath{
upstreamIDPName: &n,
code: &c,
state: &s,
code: &c,
state: &s,
}
}
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
r.upstreamIDPName = &name
return r
}
func (r *requestPath) WithCode(code string) *requestPath {
r.code = &code
return r
@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath {
}
func (r *requestPath) String() string {
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName)
path := "/downstream-provider-name/callback?"
params := url.Values{}
if r.code != nil {
params.Add("code", *r.code)
@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
func happyUpstreamStateParam() *upstreamStateParamBuilder {
return &upstreamStateParamBuilder{
U: happyUpstreamIDPName,
P: happyDownstreamRequestParams,
N: happyDownstreamNonce,
C: happyDownstreamCSRF,