callback_handler.go: Get upstream name from state instead of path
Also use ConstantTimeCompare() to compare CSRF tokens to prevent leaking any information in how quickly we reject bad tokens. Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
72321fc106
commit
b21f0035d7
@ -5,10 +5,10 @@
|
||||
package callback
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
@ -49,7 +49,7 @@ func NewHandler(
|
||||
return err
|
||||
}
|
||||
|
||||
upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter)
|
||||
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
|
||||
if upstreamIDPConfig == nil {
|
||||
plog.Warning("upstream provider not found")
|
||||
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
|
||||
@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if state.CSRFToken != csrfValue {
|
||||
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
|
||||
plog.InfoErr("CSRF value does not match", err)
|
||||
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
|
||||
}
|
||||
@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
|
||||
_, lastPathComponent := path.Split(r.URL.Path)
|
||||
func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
|
||||
for _, p := range idpListGetter.GetIDPList() {
|
||||
if p.GetName() == lastPathComponent {
|
||||
if p.GetName() == upstreamName {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ package callback
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
}
|
||||
|
||||
type requestPath struct {
|
||||
upstreamIDPName, code, state *string
|
||||
code, state *string
|
||||
}
|
||||
|
||||
func newRequestPath() *requestPath {
|
||||
n := happyUpstreamIDPName
|
||||
c := happyUpstreamAuthcode
|
||||
s := "4321"
|
||||
return &requestPath{
|
||||
upstreamIDPName: &n,
|
||||
code: &c,
|
||||
state: &s,
|
||||
code: &c,
|
||||
state: &s,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
|
||||
r.upstreamIDPName = &name
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithCode(code string) *requestPath {
|
||||
r.code = &code
|
||||
return r
|
||||
@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath {
|
||||
}
|
||||
|
||||
func (r *requestPath) String() string {
|
||||
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName)
|
||||
path := "/downstream-provider-name/callback?"
|
||||
params := url.Values{}
|
||||
if r.code != nil {
|
||||
params.Add("code", *r.code)
|
||||
@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
|
||||
|
||||
func happyUpstreamStateParam() *upstreamStateParamBuilder {
|
||||
return &upstreamStateParamBuilder{
|
||||
U: happyUpstreamIDPName,
|
||||
P: happyDownstreamRequestParams,
|
||||
N: happyDownstreamNonce,
|
||||
C: happyDownstreamCSRF,
|
||||
|
Loading…
Reference in New Issue
Block a user