WIP for saving authorize endpoint state into upstream state param
Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
005225d5f9
commit
dd190dede6
1
go.mod
1
go.mod
@ -14,6 +14,7 @@ require (
|
|||||||
github.com/golang/mock v1.4.4
|
github.com/golang/mock v1.4.4
|
||||||
github.com/golangci/golangci-lint v1.31.0
|
github.com/golangci/golangci-lint v1.31.0
|
||||||
github.com/google/go-cmp v0.5.2
|
github.com/google/go-cmp v0.5.2
|
||||||
|
github.com/gorilla/securecookie v1.1.1
|
||||||
github.com/ory/fosite v0.35.1
|
github.com/ory/fosite v0.35.1
|
||||||
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
|
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
|
||||||
github.com/sclevine/agouti v3.0.0+incompatible
|
github.com/sclevine/agouti v3.0.0+incompatible
|
||||||
|
2
go.sum
2
go.sum
@ -301,6 +301,8 @@ github.com/gookit/color v1.2.5/go.mod h1:AhIE+pS6D4Ql0SQWbBeXPHw7gY0/sjHoA4s/n1K
|
|||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
|
||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||||
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
|
||||||
|
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||||
|
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||||
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||||
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||||
|
@ -9,30 +9,42 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/plog"
|
|
||||||
|
|
||||||
"github.com/ory/fosite"
|
"github.com/ory/fosite"
|
||||||
"github.com/ory/fosite/handler/openid"
|
"github.com/ory/fosite/handler/openid"
|
||||||
"github.com/ory/fosite/token/jwt"
|
"github.com/ory/fosite/token/jwt"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/httputil/httperr"
|
"go.pinniped.dev/internal/httputil/httperr"
|
||||||
|
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||||
"go.pinniped.dev/internal/oidc/provider"
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||||
"go.pinniped.dev/internal/oidcclient/pkce"
|
"go.pinniped.dev/internal/oidcclient/pkce"
|
||||||
"go.pinniped.dev/internal/oidcclient/state"
|
"go.pinniped.dev/internal/plog"
|
||||||
"golang.org/x/oauth2"
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Just in case we need to make a breaking change to the format of the upstream state param,
|
||||||
|
// we are including a format version number. This gives the opportunity for a future version of Pinniped
|
||||||
|
// to have the consumer of this format decide to reject versions that it doesn't understand.
|
||||||
|
stateParamFormatVersion = "1"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IDPListGetter interface {
|
type IDPListGetter interface {
|
||||||
GetIDPList() []provider.UpstreamOIDCIdentityProvider
|
GetIDPList() []provider.UpstreamOIDCIdentityProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Encoder interface {
|
||||||
|
Encode(name string, value interface{}) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
func NewHandler(
|
func NewHandler(
|
||||||
issuer string,
|
issuer string,
|
||||||
idpListGetter IDPListGetter,
|
idpListGetter IDPListGetter,
|
||||||
oauthHelper fosite.OAuth2Provider,
|
oauthHelper fosite.OAuth2Provider,
|
||||||
generateState func() (state.State, error),
|
generateCSRF func() (csrftoken.CSRFToken, error),
|
||||||
generatePKCE func() (pkce.Code, error),
|
generatePKCE func() (pkce.Code, error),
|
||||||
generateNonce func() (nonce.Nonce, error),
|
generateNonce func() (nonce.Nonce, error),
|
||||||
|
encoder Encoder,
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||||||
@ -77,7 +89,7 @@ func NewHandler(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stateValue, nonceValue, pkceValue, err := generateParams(generateState, generateNonce, generatePKCE)
|
csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
plog.InfoErr("authorize generate error", err)
|
plog.InfoErr("authorize generate error", err)
|
||||||
return err
|
return err
|
||||||
@ -92,9 +104,28 @@ func NewHandler(
|
|||||||
Scopes: upstreamIDP.Scopes,
|
Scopes: upstreamIDP.Scopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "__Host-pinniped-csrf",
|
||||||
|
Value: string(csrfValue),
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
Secure: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
stateParamData := upstreamStateParamData{
|
||||||
|
AuthParams: authorizeRequester.GetRequestForm().Encode(),
|
||||||
|
Nonce: nonceValue,
|
||||||
|
CSRFToken: csrfValue,
|
||||||
|
PKCECode: pkceValue,
|
||||||
|
StateParamFormatVersion: stateParamFormatVersion,
|
||||||
|
}
|
||||||
|
encodedStateParamValue, err := encoder.Encode("s", stateParamData)
|
||||||
|
// TODO handle the above error
|
||||||
|
|
||||||
http.Redirect(w, r,
|
http.Redirect(w, r,
|
||||||
upstreamOAuthConfig.AuthCodeURL(
|
upstreamOAuthConfig.AuthCodeURL(
|
||||||
stateValue.String(),
|
encodedStateParamValue,
|
||||||
oauth2.AccessTypeOffline,
|
oauth2.AccessTypeOffline,
|
||||||
nonceValue.Param(),
|
nonceValue.Param(),
|
||||||
pkceValue.Challenge(),
|
pkceValue.Challenge(),
|
||||||
@ -107,6 +138,15 @@ func NewHandler(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
|
||||||
|
type upstreamStateParamData struct {
|
||||||
|
AuthParams string `json:"p"`
|
||||||
|
Nonce nonce.Nonce `json:"n"`
|
||||||
|
CSRFToken csrftoken.CSRFToken `json:"c"`
|
||||||
|
PKCECode pkce.Code `json:"k"`
|
||||||
|
StateParamFormatVersion string `json:"v"`
|
||||||
|
}
|
||||||
|
|
||||||
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
|
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
|
||||||
allUpstreamIDPs := idpListGetter.GetIDPList()
|
allUpstreamIDPs := idpListGetter.GetIDPList()
|
||||||
if len(allUpstreamIDPs) == 0 {
|
if len(allUpstreamIDPs) == 0 {
|
||||||
@ -123,15 +163,15 @@ func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdent
|
|||||||
return &allUpstreamIDPs[0], nil
|
return &allUpstreamIDPs[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateParams(
|
func generateValues(
|
||||||
generateState func() (state.State, error),
|
generateCSRF func() (csrftoken.CSRFToken, error),
|
||||||
generateNonce func() (nonce.Nonce, error),
|
generateNonce func() (nonce.Nonce, error),
|
||||||
generatePKCE func() (pkce.Code, error),
|
generatePKCE func() (pkce.Code, error),
|
||||||
) (state.State, nonce.Nonce, pkce.Code, error) {
|
) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) {
|
||||||
stateValue, err := generateState()
|
csrfValue, err := generateCSRF()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
plog.InfoErr("error generating state param", err)
|
plog.InfoErr("error generating csrf param", err)
|
||||||
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating state param", err)
|
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err)
|
||||||
}
|
}
|
||||||
nonceValue, err := generateNonce()
|
nonceValue, err := generateNonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -143,7 +183,7 @@ func generateParams(
|
|||||||
plog.InfoErr("error generating PKCE param", err)
|
plog.InfoErr("error generating PKCE param", err)
|
||||||
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err)
|
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err)
|
||||||
}
|
}
|
||||||
return stateValue, nonceValue, pkceValue, nil
|
return csrfValue, nonceValue, pkceValue, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fositeErrorForLog(err error) []interface{} {
|
func fositeErrorForLog(err error) []interface{} {
|
||||||
|
@ -13,16 +13,17 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
"github.com/ory/fosite"
|
"github.com/ory/fosite"
|
||||||
"github.com/ory/fosite/storage"
|
"github.com/ory/fosite/storage"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/here"
|
"go.pinniped.dev/internal/here"
|
||||||
"go.pinniped.dev/internal/oidc"
|
"go.pinniped.dev/internal/oidc"
|
||||||
|
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||||
"go.pinniped.dev/internal/oidc/provider"
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||||
"go.pinniped.dev/internal/oidcclient/pkce"
|
"go.pinniped.dev/internal/oidcclient/pkce"
|
||||||
"go.pinniped.dev/internal/oidcclient/state"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAuthorizationEndpoint(t *testing.T) {
|
func TestAuthorizationEndpoint(t *testing.T) {
|
||||||
@ -131,30 +132,37 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
||||||
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
|
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
|
||||||
|
|
||||||
happyStateGenerator := func() (state.State, error) { return "test-state", nil }
|
happyCSRF := "test-csrf"
|
||||||
happyPKCEGenerator := func() (pkce.Code, error) { return "test-pkce", nil }
|
happyPKCE := "test-pkce"
|
||||||
happyNonceGenerator := func() (nonce.Nonce, error) { return "test-nonce", nil }
|
happyNonce := "test-nonce"
|
||||||
|
happyCSRFGenerator := func() (csrftoken.CSRFToken, error) { return csrftoken.CSRFToken(happyCSRF), nil }
|
||||||
|
happyPKCEGenerator := func() (pkce.Code, error) { return pkce.Code(happyPKCE), nil }
|
||||||
|
happyNonceGenerator := func() (nonce.Nonce, error) { return nonce.Nonce(happyNonce), nil }
|
||||||
|
|
||||||
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
|
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
|
||||||
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
|
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
|
||||||
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
|
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
|
||||||
|
|
||||||
pathWithQuery := func(path string, query map[string]string) string {
|
var encoderHashKey = []byte("fake-hash-secret")
|
||||||
|
var encoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption
|
||||||
|
encoder.SetSerializer(securecookie.JSONEncoder{})
|
||||||
|
|
||||||
|
encodeQuery := func(query map[string]string) string {
|
||||||
values := url.Values{}
|
values := url.Values{}
|
||||||
for k, v := range query {
|
for k, v := range query {
|
||||||
values[k] = []string{v}
|
values[k] = []string{v}
|
||||||
}
|
}
|
||||||
pathToReturn := fmt.Sprintf("%s?%s", path, values.Encode())
|
return values.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
pathWithQuery := func(path string, query map[string]string) string {
|
||||||
|
pathToReturn := fmt.Sprintf("%s?%s", path, encodeQuery(query))
|
||||||
require.NotRegexp(t, "^http", pathToReturn, "pathWithQuery helper was used to create a URL")
|
require.NotRegexp(t, "^http", pathToReturn, "pathWithQuery helper was used to create a URL")
|
||||||
return pathToReturn
|
return pathToReturn
|
||||||
}
|
}
|
||||||
|
|
||||||
urlWithQuery := func(baseURL string, query map[string]string) string {
|
urlWithQuery := func(baseURL string, query map[string]string) string {
|
||||||
values := url.Values{}
|
urlToReturn := fmt.Sprintf("%s?%s", baseURL, encodeQuery(query))
|
||||||
for k, v := range query {
|
|
||||||
values[k] = []string{v}
|
|
||||||
}
|
|
||||||
urlToReturn := fmt.Sprintf("%s?%s", baseURL, values.Encode())
|
|
||||||
_, err := url.Parse(urlToReturn)
|
_, err := url.Parse(urlToReturn)
|
||||||
require.NoError(t, err, "urlWithQuery helper was used to create an illegal URL")
|
require.NoError(t, err, "urlWithQuery helper was used to create an illegal URL")
|
||||||
return urlToReturn
|
return urlToReturn
|
||||||
@ -189,26 +197,47 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap)
|
return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We're going to use this value to make assertions, so specify the exact expected value.
|
||||||
|
happyUpstreamStateParam, err := encoder.Encode("s",
|
||||||
|
// Ensure that the order of the serialized fields is exactly this order, so we can make simpler equality assertions below.
|
||||||
|
struct {
|
||||||
|
P string `json:"p"`
|
||||||
|
N string `json:"n"`
|
||||||
|
C string `json:"c"`
|
||||||
|
K string `json:"k"`
|
||||||
|
V string `json:"v"`
|
||||||
|
}{
|
||||||
|
P: encodeQuery(happyGetRequestQueryMap),
|
||||||
|
N: happyNonce,
|
||||||
|
C: happyCSRF,
|
||||||
|
K: happyPKCE,
|
||||||
|
V: "1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(),
|
happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(),
|
||||||
map[string]string{
|
map[string]string{
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"access_type": "offline",
|
"access_type": "offline",
|
||||||
"scope": "scope1 scope2",
|
"scope": "scope1 scope2",
|
||||||
"client_id": "some-client-id",
|
"client_id": "some-client-id",
|
||||||
"state": "test-state",
|
"state": happyUpstreamStateParam,
|
||||||
"nonce": "test-nonce",
|
"nonce": happyNonce,
|
||||||
"code_challenge": expectedUpstreamCodeChallenge,
|
"code_challenge": expectedUpstreamCodeChallenge,
|
||||||
"code_challenge_method": "S256",
|
"code_challenge_method": "S256",
|
||||||
"redirect_uri": issuer + "/callback/some-idp",
|
"redirect_uri": issuer + "/callback/some-idp",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
issuer string
|
issuer string
|
||||||
idpListGetter provider.DynamicUpstreamIDPProvider
|
idpListGetter provider.DynamicUpstreamIDPProvider
|
||||||
generateState func() (state.State, error)
|
generateCSRF func() (csrftoken.CSRFToken, error)
|
||||||
generatePKCE func() (pkce.Code, error)
|
generatePKCE func() (pkce.Code, error)
|
||||||
generateNonce func() (nonce.Nonce, error)
|
generateNonce func() (nonce.Nonce, error)
|
||||||
method string
|
method string
|
||||||
@ -221,6 +250,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
wantBodyString string
|
wantBodyString string
|
||||||
wantBodyJSON string
|
wantBodyJSON string
|
||||||
wantLocationHeader string
|
wantLocationHeader string
|
||||||
|
wantCSRFCookieHeader string
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []testCase{
|
tests := []testCase{
|
||||||
@ -228,7 +258,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "happy path using GET",
|
name: "happy path using GET",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -240,12 +270,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
"\n\n",
|
"\n\n",
|
||||||
),
|
),
|
||||||
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
||||||
|
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "happy path using POST",
|
name: "happy path using POST",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodPost,
|
method: http.MethodPost,
|
||||||
@ -264,12 +295,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
wantContentType: "",
|
wantContentType: "",
|
||||||
wantBodyString: "",
|
wantBodyString: "",
|
||||||
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
||||||
|
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "downstream client does not exist",
|
name: "downstream client does not exist",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -282,7 +314,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "downstream redirect uri does not match what is configured for client",
|
name: "downstream redirect uri does not match what is configured for client",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -294,10 +326,10 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
wantBodyJSON: fositeInvalidRedirectURIErrorBody,
|
wantBodyJSON: fositeInvalidRedirectURIErrorBody,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "downstream redirect uri matches what is configured for client except for the port number",
|
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -311,12 +343,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
"\n\n",
|
"\n\n",
|
||||||
),
|
),
|
||||||
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
|
||||||
|
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "response type is unsupported",
|
name: "response type is unsupported",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -330,7 +363,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "downstream scopes do not match what is configured for client",
|
name: "downstream scopes do not match what is configured for client",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -344,7 +377,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "missing response type in request",
|
name: "missing response type in request",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -358,7 +391,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "missing client id in request",
|
name: "missing client id in request",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -371,7 +404,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -385,7 +418,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -399,7 +432,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -413,7 +446,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -429,7 +462,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "prompt param is not allowed to have none and another legal value at the same time",
|
name: "prompt param is not allowed to have none and another legal value at the same time",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -443,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "state does not have enough entropy",
|
name: "state does not have enough entropy",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -454,23 +487,23 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
wantBodyString: "",
|
wantBodyString: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error while generating state",
|
name: "error while generating CSRF token",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: func() (state.State, error) { return "", fmt.Errorf("some state generator error") },
|
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: happyGetRequestPath,
|
path: happyGetRequestPath,
|
||||||
wantStatus: http.StatusInternalServerError,
|
wantStatus: http.StatusInternalServerError,
|
||||||
wantContentType: "text/plain; charset=utf-8",
|
wantContentType: "text/plain; charset=utf-8",
|
||||||
wantBodyString: "Internal Server Error: error generating state param\n",
|
wantBodyString: "Internal Server Error: error generating CSRF token\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error while generating nonce",
|
name: "error while generating nonce",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -483,7 +516,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
name: "error while generating PKCE",
|
name: "error while generating PKCE",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||||
generateState: happyStateGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
@ -562,13 +595,23 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
if test.wantLocationHeader != "" {
|
if test.wantLocationHeader != "" {
|
||||||
actualLocation := rsp.Header().Get("Location")
|
actualLocation := rsp.Header().Get("Location")
|
||||||
requireEqualURLs(t, actualLocation, test.wantLocationHeader)
|
requireEqualURLs(t, actualLocation, test.wantLocationHeader)
|
||||||
|
} else {
|
||||||
|
require.Empty(t, rsp.Header().Values("Location"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.wantCSRFCookieHeader != "" {
|
||||||
|
require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
|
||||||
|
actualCookie := rsp.Header().Get("Set-Cookie")
|
||||||
|
require.Equal(t, actualCookie, test.wantCSRFCookieHeader)
|
||||||
|
} else {
|
||||||
|
require.Empty(t, rsp.Header().Values("Set-Cookie"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateState, test.generatePKCE, test.generateNonce)
|
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
|
||||||
runOneTestCase(t, test, subject)
|
runOneTestCase(t, test, subject)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -577,7 +620,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
test := tests[0]
|
test := tests[0]
|
||||||
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case
|
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case
|
||||||
|
|
||||||
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateState, test.generatePKCE, test.generateNonce)
|
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
|
||||||
|
|
||||||
runOneTestCase(t, test, subject)
|
runOneTestCase(t, test, subject)
|
||||||
|
|
||||||
@ -597,8 +640,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
"access_type": "offline",
|
"access_type": "offline",
|
||||||
"scope": "other-scope1 other-scope2",
|
"scope": "other-scope1 other-scope2",
|
||||||
"client_id": "some-other-client-id",
|
"client_id": "some-other-client-id",
|
||||||
"state": "test-state",
|
"state": happyUpstreamStateParam,
|
||||||
"nonce": "test-nonce",
|
"nonce": happyNonce,
|
||||||
"code_challenge": expectedUpstreamCodeChallenge,
|
"code_challenge": expectedUpstreamCodeChallenge,
|
||||||
"code_challenge_method": "S256",
|
"code_challenge_method": "S256",
|
||||||
"redirect_uri": issuer + "/callback/some-other-idp",
|
"redirect_uri": issuer + "/callback/some-other-idp",
|
||||||
|
24
internal/oidc/csrftoken/csrftoken.go
Normal file
24
internal/oidc/csrftoken/csrftoken.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
package csrftoken
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate generates a new random CSRF token value.
|
||||||
|
func Generate() (CSRFToken, error) { return generate(rand.Reader) }
|
||||||
|
|
||||||
|
func generate(rand io.Reader) (CSRFToken, error) {
|
||||||
|
var buf [32]byte
|
||||||
|
if _, err := io.ReadFull(rand, buf[:]); err != nil {
|
||||||
|
return "", fmt.Errorf("could not generate CSRFToken: %w", err)
|
||||||
|
}
|
||||||
|
return CSRFToken(hex.EncodeToString(buf[:])), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type CSRFToken string
|
@ -8,6 +8,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||||
|
|
||||||
"github.com/ory/fosite"
|
"github.com/ory/fosite"
|
||||||
"github.com/ory/fosite/storage"
|
"github.com/ory/fosite/storage"
|
||||||
|
|
||||||
@ -18,7 +21,6 @@ import (
|
|||||||
"go.pinniped.dev/internal/oidc/provider"
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||||
"go.pinniped.dev/internal/oidcclient/pkce"
|
"go.pinniped.dev/internal/oidcclient/pkce"
|
||||||
"go.pinniped.dev/internal/oidcclient/state"
|
|
||||||
"go.pinniped.dev/internal/plog"
|
"go.pinniped.dev/internal/plog"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,8 +80,13 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
|
|||||||
}
|
}
|
||||||
oauthHelper := oidc.FositeOauth2Helper(oauthStore, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret
|
oauthHelper := oidc.FositeOauth2Helper(oauthStore, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret
|
||||||
|
|
||||||
|
var encoderHashKey = []byte("fake-hash-secret") // TODO fix this
|
||||||
|
var encoderBlockKey = []byte("16-bytes-aaaaaaa") // TODO fix this
|
||||||
|
var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
|
||||||
|
encoder.SetSerializer(securecookie.JSONEncoder{})
|
||||||
|
|
||||||
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
|
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
|
||||||
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, state.Generate, pkce.Generate, nonce.Generate)
|
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder)
|
||||||
|
|
||||||
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
|
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user