WIP for saving authorize endpoint state into upstream state param

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Monis Khan 2020-11-10 17:58:00 -08:00 committed by Ryan Richard
parent 005225d5f9
commit dd190dede6
6 changed files with 183 additions and 66 deletions

1
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/golang/mock v1.4.4 github.com/golang/mock v1.4.4
github.com/golangci/golangci-lint v1.31.0 github.com/golangci/golangci-lint v1.31.0
github.com/google/go-cmp v0.5.2 github.com/google/go-cmp v0.5.2
github.com/gorilla/securecookie v1.1.1
github.com/ory/fosite v0.35.1 github.com/ory/fosite v0.35.1
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/sclevine/agouti v3.0.0+incompatible github.com/sclevine/agouti v3.0.0+incompatible

2
go.sum
View File

@ -301,6 +301,8 @@ github.com/gookit/color v1.2.5/go.mod h1:AhIE+pS6D4Ql0SQWbBeXPHw7gY0/sjHoA4s/n1K
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8=
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v0.0.0-20170926233335-4201258b820c/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=

View File

@ -9,30 +9,42 @@ import (
"net/http" "net/http"
"time" "time"
"go.pinniped.dev/internal/plog"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state" "go.pinniped.dev/internal/plog"
"golang.org/x/oauth2" )
const (
// Just in case we need to make a breaking change to the format of the upstream state param,
// we are including a format version number. This gives the opportunity for a future version of Pinniped
// to have the consumer of this format decide to reject versions that it doesn't understand.
stateParamFormatVersion = "1"
) )
type IDPListGetter interface { type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider GetIDPList() []provider.UpstreamOIDCIdentityProvider
} }
type Encoder interface {
Encode(name string, value interface{}) (string, error)
}
func NewHandler( func NewHandler(
issuer string, issuer string,
idpListGetter IDPListGetter, idpListGetter IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
generateState func() (state.State, error), generateCSRF func() (csrftoken.CSRFToken, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
encoder Encoder,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost && r.Method != http.MethodGet { if r.Method != http.MethodPost && r.Method != http.MethodGet {
@ -77,7 +89,7 @@ func NewHandler(
return nil return nil
} }
stateValue, nonceValue, pkceValue, err := generateParams(generateState, generateNonce, generatePKCE) csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE)
if err != nil { if err != nil {
plog.InfoErr("authorize generate error", err) plog.InfoErr("authorize generate error", err)
return err return err
@ -92,9 +104,28 @@ func NewHandler(
Scopes: upstreamIDP.Scopes, Scopes: upstreamIDP.Scopes,
} }
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
http.SetCookie(w, &http.Cookie{
Name: "__Host-pinniped-csrf",
Value: string(csrfValue),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: true,
})
stateParamData := upstreamStateParamData{
AuthParams: authorizeRequester.GetRequestForm().Encode(),
Nonce: nonceValue,
CSRFToken: csrfValue,
PKCECode: pkceValue,
StateParamFormatVersion: stateParamFormatVersion,
}
encodedStateParamValue, err := encoder.Encode("s", stateParamData)
// TODO handle the above error
http.Redirect(w, r, http.Redirect(w, r,
upstreamOAuthConfig.AuthCodeURL( upstreamOAuthConfig.AuthCodeURL(
stateValue.String(), encodedStateParamValue,
oauth2.AccessTypeOffline, oauth2.AccessTypeOffline,
nonceValue.Param(), nonceValue.Param(),
pkceValue.Challenge(), pkceValue.Challenge(),
@ -107,6 +138,15 @@ func NewHandler(
}) })
} }
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
type upstreamStateParamData struct {
AuthParams string `json:"p"`
Nonce nonce.Nonce `json:"n"`
CSRFToken csrftoken.CSRFToken `json:"c"`
PKCECode pkce.Code `json:"k"`
StateParamFormatVersion string `json:"v"`
}
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
@ -123,15 +163,15 @@ func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdent
return &allUpstreamIDPs[0], nil return &allUpstreamIDPs[0], nil
} }
func generateParams( func generateValues(
generateState func() (state.State, error), generateCSRF func() (csrftoken.CSRFToken, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
) (state.State, nonce.Nonce, pkce.Code, error) { ) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) {
stateValue, err := generateState() csrfValue, err := generateCSRF()
if err != nil { if err != nil {
plog.InfoErr("error generating state param", err) plog.InfoErr("error generating csrf param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating state param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err)
} }
nonceValue, err := generateNonce() nonceValue, err := generateNonce()
if err != nil { if err != nil {
@ -143,7 +183,7 @@ func generateParams(
plog.InfoErr("error generating PKCE param", err) plog.InfoErr("error generating PKCE param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err)
} }
return stateValue, nonceValue, pkceValue, nil return csrfValue, nonceValue, pkceValue, nil
} }
func fositeErrorForLog(err error) []interface{} { func fositeErrorForLog(err error) []interface{} {

View File

@ -13,16 +13,17 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/gorilla/securecookie"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/storage" "github.com/ory/fosite/storage"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state"
) )
func TestAuthorizationEndpoint(t *testing.T) { func TestAuthorizationEndpoint(t *testing.T) {
@ -131,30 +132,37 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
happyStateGenerator := func() (state.State, error) { return "test-state", nil } happyCSRF := "test-csrf"
happyPKCEGenerator := func() (pkce.Code, error) { return "test-pkce", nil } happyPKCE := "test-pkce"
happyNonceGenerator := func() (nonce.Nonce, error) { return "test-nonce", nil } happyNonce := "test-nonce"
happyCSRFGenerator := func() (csrftoken.CSRFToken, error) { return csrftoken.CSRFToken(happyCSRF), nil }
happyPKCEGenerator := func() (pkce.Code, error) { return pkce.Code(happyPKCE), nil }
happyNonceGenerator := func() (nonce.Nonce, error) { return nonce.Nonce(happyNonce), nil }
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
pathWithQuery := func(path string, query map[string]string) string { var encoderHashKey = []byte("fake-hash-secret")
var encoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption
encoder.SetSerializer(securecookie.JSONEncoder{})
encodeQuery := func(query map[string]string) string {
values := url.Values{} values := url.Values{}
for k, v := range query { for k, v := range query {
values[k] = []string{v} values[k] = []string{v}
} }
pathToReturn := fmt.Sprintf("%s?%s", path, values.Encode()) return values.Encode()
}
pathWithQuery := func(path string, query map[string]string) string {
pathToReturn := fmt.Sprintf("%s?%s", path, encodeQuery(query))
require.NotRegexp(t, "^http", pathToReturn, "pathWithQuery helper was used to create a URL") require.NotRegexp(t, "^http", pathToReturn, "pathWithQuery helper was used to create a URL")
return pathToReturn return pathToReturn
} }
urlWithQuery := func(baseURL string, query map[string]string) string { urlWithQuery := func(baseURL string, query map[string]string) string {
values := url.Values{} urlToReturn := fmt.Sprintf("%s?%s", baseURL, encodeQuery(query))
for k, v := range query {
values[k] = []string{v}
}
urlToReturn := fmt.Sprintf("%s?%s", baseURL, values.Encode())
_, err := url.Parse(urlToReturn) _, err := url.Parse(urlToReturn)
require.NoError(t, err, "urlWithQuery helper was used to create an illegal URL") require.NoError(t, err, "urlWithQuery helper was used to create an illegal URL")
return urlToReturn return urlToReturn
@ -189,26 +197,47 @@ func TestAuthorizationEndpoint(t *testing.T) {
return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap) return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap)
} }
// We're going to use this value to make assertions, so specify the exact expected value.
happyUpstreamStateParam, err := encoder.Encode("s",
// Ensure that the order of the serialized fields is exactly this order, so we can make simpler equality assertions below.
struct {
P string `json:"p"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}{
P: encodeQuery(happyGetRequestQueryMap),
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: "1",
},
)
require.NoError(t, err)
happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(), happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(),
map[string]string{ map[string]string{
"response_type": "code", "response_type": "code",
"access_type": "offline", "access_type": "offline",
"scope": "scope1 scope2", "scope": "scope1 scope2",
"client_id": "some-client-id", "client_id": "some-client-id",
"state": "test-state", "state": happyUpstreamStateParam,
"nonce": "test-nonce", "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-idp", "redirect_uri": issuer + "/callback/some-idp",
}, },
) )
happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF)
type testCase struct { type testCase struct {
name string name string
issuer string issuer string
idpListGetter provider.DynamicUpstreamIDPProvider idpListGetter provider.DynamicUpstreamIDPProvider
generateState func() (state.State, error) generateCSRF func() (csrftoken.CSRFToken, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
method string method string
@ -216,11 +245,12 @@ func TestAuthorizationEndpoint(t *testing.T) {
contentType string contentType string
body string body string
wantStatus int wantStatus int
wantContentType string wantContentType string
wantBodyString string wantBodyString string
wantBodyJSON string wantBodyJSON string
wantLocationHeader string wantLocationHeader string
wantCSRFCookieHeader string
} }
tests := []testCase{ tests := []testCase{
@ -228,7 +258,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "happy path using GET", name: "happy path using GET",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -239,13 +269,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
html.EscapeString(happyGetRequestExpectedRedirectLocation), html.EscapeString(happyGetRequestExpectedRedirectLocation),
"\n\n", "\n\n",
), ),
wantLocationHeader: happyGetRequestExpectedRedirectLocation, wantLocationHeader: happyGetRequestExpectedRedirectLocation,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
}, },
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodPost, method: http.MethodPost,
@ -260,16 +291,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURI}, "redirect_uri": []string{downstreamRedirectURI},
}.Encode(), }.Encode(),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "", wantContentType: "",
wantBodyString: "", wantBodyString: "",
wantLocationHeader: happyGetRequestExpectedRedirectLocation, wantLocationHeader: happyGetRequestExpectedRedirectLocation,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
}, },
{ {
name: "downstream client does not exist", name: "downstream client does not exist",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -282,7 +314,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -294,10 +326,10 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyJSON: fositeInvalidRedirectURIErrorBody, wantBodyJSON: fositeInvalidRedirectURIErrorBody,
}, },
{ {
name: "downstream redirect uri matches what is configured for client except for the port number", name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -310,13 +342,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
html.EscapeString(happyGetRequestExpectedRedirectLocation), html.EscapeString(happyGetRequestExpectedRedirectLocation),
"\n\n", "\n\n",
), ),
wantLocationHeader: happyGetRequestExpectedRedirectLocation, wantLocationHeader: happyGetRequestExpectedRedirectLocation,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
}, },
{ {
name: "response type is unsupported", name: "response type is unsupported",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -330,7 +363,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "downstream scopes do not match what is configured for client", name: "downstream scopes do not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -344,7 +377,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "missing response type in request", name: "missing response type in request",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -358,7 +391,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "missing client id in request", name: "missing client id in request",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -371,7 +404,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -385,7 +418,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -399,7 +432,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -413,7 +446,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -429,7 +462,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "prompt param is not allowed to have none and another legal value at the same time", name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -443,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "state does not have enough entropy", name: "state does not have enough entropy",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -454,23 +487,23 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyString: "", wantBodyString: "",
}, },
{ {
name: "error while generating state", name: "error while generating CSRF token",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: func() (state.State, error) { return "", fmt.Errorf("some state generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
wantContentType: "text/plain; charset=utf-8", wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error generating state param\n", wantBodyString: "Internal Server Error: error generating CSRF token\n",
}, },
{ {
name: "error while generating nonce", name: "error while generating nonce",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
method: http.MethodGet, method: http.MethodGet,
@ -483,7 +516,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
name: "error while generating PKCE", name: "error while generating PKCE",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateState: happyStateGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, method: http.MethodGet,
@ -562,13 +595,23 @@ func TestAuthorizationEndpoint(t *testing.T) {
if test.wantLocationHeader != "" { if test.wantLocationHeader != "" {
actualLocation := rsp.Header().Get("Location") actualLocation := rsp.Header().Get("Location")
requireEqualURLs(t, actualLocation, test.wantLocationHeader) requireEqualURLs(t, actualLocation, test.wantLocationHeader)
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
if test.wantCSRFCookieHeader != "" {
require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
actualCookie := rsp.Header().Get("Set-Cookie")
require.Equal(t, actualCookie, test.wantCSRFCookieHeader)
} else {
require.Empty(t, rsp.Header().Values("Set-Cookie"))
} }
} }
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateState, test.generatePKCE, test.generateNonce) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
}) })
} }
@ -577,7 +620,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
test := tests[0] test := tests[0]
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateState, test.generatePKCE, test.generateNonce) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
@ -597,8 +640,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
"access_type": "offline", "access_type": "offline",
"scope": "other-scope1 other-scope2", "scope": "other-scope1 other-scope2",
"client_id": "some-other-client-id", "client_id": "some-other-client-id",
"state": "test-state", "state": happyUpstreamStateParam,
"nonce": "test-nonce", "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-other-idp", "redirect_uri": issuer + "/callback/some-other-idp",

View File

@ -0,0 +1,24 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package csrftoken
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
// Generate generates a new random CSRF token value.
func Generate() (CSRFToken, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (CSRFToken, error) {
var buf [32]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", fmt.Errorf("could not generate CSRFToken: %w", err)
}
return CSRFToken(hex.EncodeToString(buf[:])), nil
}
type CSRFToken string

View File

@ -8,6 +8,9 @@ import (
"strings" "strings"
"sync" "sync"
"github.com/gorilla/securecookie"
"go.pinniped.dev/internal/oidc/csrftoken"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/storage" "github.com/ory/fosite/storage"
@ -18,7 +21,6 @@ import (
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -78,8 +80,13 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
} }
oauthHelper := oidc.FositeOauth2Helper(oauthStore, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret oauthHelper := oidc.FositeOauth2Helper(oauthStore, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret
var encoderHashKey = []byte("fake-hash-secret") // TODO fix this
var encoderBlockKey = []byte("16-bytes-aaaaaaa") // TODO fix this
var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
encoder.SetSerializer(securecookie.JSONEncoder{})
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, state.Generate, pkce.Generate, nonce.Generate) m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder)
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
} }