callback_handler.go: Get upstream name from state instead of path
Also use ConstantTimeCompare() to compare CSRF tokens to prevent leaking any information in how quickly we reject bad tokens. Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
72321fc106
commit
b21f0035d7
@ -5,10 +5,10 @@
|
|||||||
package callback
|
package callback
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/subtle"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ory/fosite"
|
"github.com/ory/fosite"
|
||||||
@ -49,7 +49,7 @@ func NewHandler(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter)
|
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, idpListGetter)
|
||||||
if upstreamIDPConfig == nil {
|
if upstreamIDPConfig == nil {
|
||||||
plog.Warning("upstream provider not found")
|
plog.Warning("upstream provider not found")
|
||||||
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
|
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
|
||||||
@ -137,7 +137,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if state.CSRFToken != csrfValue {
|
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
|
||||||
plog.InfoErr("CSRF value does not match", err)
|
plog.InfoErr("CSRF value does not match", err)
|
||||||
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
|
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
|
||||||
}
|
}
|
||||||
@ -145,10 +145,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
|
|||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
|
func findUpstreamIDPConfig(upstreamName string, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
|
||||||
_, lastPathComponent := path.Split(r.URL.Path)
|
|
||||||
for _, p := range idpListGetter.GetIDPList() {
|
for _, p := range idpListGetter.GetIDPList() {
|
||||||
if p.GetName() == lastPathComponent {
|
if p.GetName() == upstreamName {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,6 @@ package callback
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -507,25 +506,18 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type requestPath struct {
|
type requestPath struct {
|
||||||
upstreamIDPName, code, state *string
|
code, state *string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRequestPath() *requestPath {
|
func newRequestPath() *requestPath {
|
||||||
n := happyUpstreamIDPName
|
|
||||||
c := happyUpstreamAuthcode
|
c := happyUpstreamAuthcode
|
||||||
s := "4321"
|
s := "4321"
|
||||||
return &requestPath{
|
return &requestPath{
|
||||||
upstreamIDPName: &n,
|
code: &c,
|
||||||
code: &c,
|
state: &s,
|
||||||
state: &s,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
|
|
||||||
r.upstreamIDPName = &name
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *requestPath) WithCode(code string) *requestPath {
|
func (r *requestPath) WithCode(code string) *requestPath {
|
||||||
r.code = &code
|
r.code = &code
|
||||||
return r
|
return r
|
||||||
@ -547,7 +539,7 @@ func (r *requestPath) WithoutState() *requestPath {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *requestPath) String() string {
|
func (r *requestPath) String() string {
|
||||||
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName)
|
path := "/downstream-provider-name/callback?"
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
if r.code != nil {
|
if r.code != nil {
|
||||||
params.Add("code", *r.code)
|
params.Add("code", *r.code)
|
||||||
@ -562,6 +554,7 @@ type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
|
|||||||
|
|
||||||
func happyUpstreamStateParam() *upstreamStateParamBuilder {
|
func happyUpstreamStateParam() *upstreamStateParamBuilder {
|
||||||
return &upstreamStateParamBuilder{
|
return &upstreamStateParamBuilder{
|
||||||
|
U: happyUpstreamIDPName,
|
||||||
P: happyDownstreamRequestParams,
|
P: happyDownstreamRequestParams,
|
||||||
N: happyDownstreamNonce,
|
N: happyDownstreamNonce,
|
||||||
C: happyDownstreamCSRF,
|
C: happyDownstreamCSRF,
|
||||||
|
Loading…
Reference in New Issue
Block a user