Supervisor authorize endpoint reuses existing CSRF cookies and signs new ones

- To better support having multiple downstream providers configured,
  the authorize endpoint will share a CSRF cookie between all
  downstream providers' authorize endpoints. The first time a
  user's browser hits the authorize endpoint of any downstream
  provider, that endpoint will set the cookie. Then if the user
  starts an authorize flow with that same downstream provider or with
  any other downstream provider which shares the same domain name
  (i.e. differentiated by issuer path), then the same cookie will be
  submitted and respected.
- Just in case we are sharing the domain name with some other app,
  we sign the value of any new CSRF cookie and check the signature
  when we receive the cookie. This wasn't strictly necessary since
  we probably won't share a domain name with other apps, but it
  wasn't hard to add this cookie signing.

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-12 15:36:59 -08:00 committed by Ryan Richard
parent d73fdb1d33
commit 080bb594b2
3 changed files with 217 additions and 67 deletions

View File

@ -9,6 +9,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/gorilla/securecookie"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt" "github.com/ory/fosite/token/jwt"
@ -35,6 +37,9 @@ const (
// The name of the browser cookie which shall hold our CSRF value. // The name of the browser cookie which shall hold our CSRF value.
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
csrfCookieName = "__Host-pinniped-csrf" csrfCookieName = "__Host-pinniped-csrf"
// The `name` passed to the encoder for encoding and decoding the CSRF cookie contents.
csrfCookieEncodingName = "csrf"
) )
type IDPListGetter interface { type IDPListGetter interface {
@ -53,7 +58,8 @@ func NewHandler(
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
encoder Encoder, upstreamStateEncoder Encoder,
cookieCodec securecookie.Codec,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost && r.Method != http.MethodGet { if r.Method != http.MethodPost && r.Method != http.MethodGet {
@ -63,6 +69,12 @@ func NewHandler(
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method)
} }
csrfFromCookie, err := readCSRFCookie(r, cookieCodec)
if err != nil {
plog.InfoErr("error reading CSRF cookie", err)
return err
}
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r)
if err != nil { if err != nil {
plog.Info("authorize request error", fositeErrorForLog(err)...) plog.Info("authorize request error", fositeErrorForLog(err)...)
@ -77,11 +89,7 @@ func NewHandler(
} }
// Grant the openid scope (for now) if they asked for it so that `NewAuthorizeResponse` will perform its OIDC validations. // Grant the openid scope (for now) if they asked for it so that `NewAuthorizeResponse` will perform its OIDC validations.
for _, scope := range authorizeRequester.GetRequestedScopes() { grantOpenIDScopeIfRequested(authorizeRequester)
if scope == "openid" {
authorizeRequester.GrantScope(scope)
}
}
now := time.Now() now := time.Now()
_, err = oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ _, err = oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{
@ -103,6 +111,9 @@ func NewHandler(
plog.Error("authorize generate error", err) plog.Error("authorize generate error", err)
return err return err
} }
if csrfFromCookie != "" {
csrfValue = csrfFromCookie
}
upstreamOAuthConfig := oauth2.Config{ upstreamOAuthConfig := oauth2.Config{
ClientID: upstreamIDP.ClientID, ClientID: upstreamIDP.ClientID,
@ -113,13 +124,20 @@ func NewHandler(
Scopes: upstreamIDP.Scopes, Scopes: upstreamIDP.Scopes,
} }
encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, encoder) encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder)
if err != nil { if err != nil {
plog.Error("authorize upstream state param error", err) plog.Error("authorize upstream state param error", err)
return err return err
} }
addCSRFSetCookieHeader(w, csrfValue) if csrfFromCookie == "" {
// We did not receive an incoming CSRF cookie, so write a new one.
err := addCSRFSetCookieHeader(w, csrfValue, cookieCodec)
if err != nil {
plog.Error("error setting CSRF cookie", err)
return err
}
}
http.Redirect(w, r, http.Redirect(w, r,
upstreamOAuthConfig.AuthCodeURL( upstreamOAuthConfig.AuthCodeURL(
@ -136,6 +154,30 @@ func NewHandler(
}) })
} }
func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(csrfCookieName)
if err != nil {
// Error means that the cookie was not found
return "", nil
}
var csrfFromCookie csrftoken.CSRFToken
err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil {
return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err)
}
return csrfFromCookie, nil
}
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
for _, scope := range authorizeRequester.GetRequestedScopes() {
if scope == "openid" {
authorizeRequester.GrantScope(scope)
}
}
}
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
@ -202,14 +244,21 @@ func upstreamStateParam(
return encodedStateParamValue, nil return encodedStateParamValue, nil
} }
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken) { func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error {
encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue)
if err != nil {
return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err)
}
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: csrfCookieName, Name: csrfCookieName,
Value: string(csrfValue), Value: encodedCSRFValue,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
Secure: true, Secure: true,
}) })
return nil
} }
func fositeErrorForLog(err error) []interface{} { func fositeErrorForLog(err error) []interface{} {

View File

@ -10,6 +10,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"strings" "strings"
"testing" "testing"
@ -137,10 +138,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
var encoderHashKey = []byte("fake-hash-secret") var stateEncoderHashKey = []byte("fake-hash-secret")
var encoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
var happyEncoder = securecookie.New(encoderHashKey, encoderBlockKey) // note that nil block key argument turns off encryption var cookieEncoderHashKey = []byte("fake-hash-secret2")
happyEncoder.SetSerializer(securecookie.JSONEncoder{}) var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
happyStateEncoder.SetSerializer(securecookie.JSONEncoder{})
var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{})
encodeQuery := func(query map[string]string) string { encodeQuery := func(query map[string]string) string {
values := url.Values{} values := url.Values{}
@ -196,12 +204,16 @@ func TestAuthorizationEndpoint(t *testing.T) {
return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides))
} }
expectedUpstreamStateParam := func(queryOverrides map[string]string) string { expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string {
encoded, err := happyEncoder.Encode("s", csrf := happyCSRF
if csrfValueOverride != "" {
csrf = csrfValueOverride
}
encoded, err := happyStateEncoder.Encode("s",
expectedUpstreamStateParamFormat{ expectedUpstreamStateParamFormat{
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
N: happyNonce, N: happyNonce,
C: happyCSRF, C: csrf,
K: happyPKCE, K: happyPKCE,
V: "1", V: "1",
}, },
@ -224,7 +236,9 @@ func TestAuthorizationEndpoint(t *testing.T) {
}) })
} }
happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF) incomingCookieCSRFValue := "csrf-value-from-cookie"
encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue)
require.NoError(t, err)
type testCase struct { type testCase struct {
name string name string
@ -234,37 +248,58 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF func() (csrftoken.CSRFToken, error) generateCSRF func() (csrftoken.CSRFToken, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
encoder securecookie.Codec stateEncoder securecookie.Codec
cookieEncoder securecookie.Codec
method string method string
path string path string
contentType string contentType string
body string body string
csrfCookie string
wantStatus int wantStatus int
wantContentType string wantContentType string
wantBodyString string wantBodyString string
wantBodyJSON string wantBodyJSON string
wantLocationHeader string wantLocationHeader string
wantCSRFCookieHeader string wantCSRFValueInCookieHeader string
wantUpstreamStateParamInLocationHeader bool wantUpstreamStateParamInLocationHeader bool
wantBodyStringWithLocationInHref bool wantBodyStringWithLocationInHref bool
} }
tests := []testCase{ tests := []testCase{
{ {
name: "happy path using GET", name: "happy path using GET without a CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil)), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")),
wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true,
},
{
name: "happy path using GET with a CSRF cookie",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue,
wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8",
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
@ -275,7 +310,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodPost, method: http.MethodPost,
path: "/some/path", path: "/some/path",
contentType: "application/x-www-form-urlencoded", contentType: "application/x-www-form-urlencoded",
@ -283,8 +319,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "", wantContentType: "",
wantBodyString: "", wantBodyString: "",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil)), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
}, },
{ {
@ -294,17 +330,18 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{ path: modifiedHappyGetRequestPath(map[string]string{
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
}), }),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
})), }, "")),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
@ -315,7 +352,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{ path: modifiedHappyGetRequestPath(map[string]string{
"redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client", "redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client",
@ -331,7 +369,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}), path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
@ -345,7 +384,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}), path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -360,7 +400,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}), path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -375,7 +416,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}), path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -390,7 +432,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}), path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
@ -404,7 +447,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -419,7 +463,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -434,7 +479,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -449,7 +495,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -466,7 +513,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}), path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -481,14 +529,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
// The following prompt value is illegal when openid is requested, but note that openid is not requested. // The following prompt value is illegal when openid is requested, but note that openid is not requested.
path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}), path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, wantCSRFValueInCookieHeader: happyCSRF,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{"prompt": "none login", "scope": "email"})), wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(
map[string]string{"prompt": "none login", "scope": "email"}, "",
)),
wantUpstreamStateParamInLocationHeader: true, wantUpstreamStateParamInLocationHeader: true,
wantBodyStringWithLocationInHref: true, wantBodyStringWithLocationInHref: true,
}, },
@ -499,7 +550,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}), path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -514,13 +566,29 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: &errorReturningEncoder{}, stateEncoder: &errorReturningEncoder{},
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
wantContentType: "text/plain; charset=utf-8", wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error encoding upstream state param\n", wantBodyString: "Internal Server Error: error encoding upstream state param\n",
}, },
{
name: "error while encoding CSRF cookie value for new cookie",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
stateEncoder: happyStateEncoder,
cookieEncoder: &errorReturningEncoder{},
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError,
wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error encoding CSRF cookie\n",
},
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: issuer,
@ -528,7 +596,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
@ -542,7 +611,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
@ -556,13 +626,30 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder, stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
wantContentType: "text/plain; charset=utf-8", wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error generating PKCE param\n", wantBodyString: "Internal Server Error: error generating PKCE param\n",
}, },
{
name: "error while decoding CSRF cookie",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
stateEncoder: happyStateEncoder,
cookieEncoder: happyCookieEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusUnprocessableEntity,
wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n",
},
{ {
name: "no upstream providers are configured", name: "no upstream providers are configured",
issuer: issuer, issuer: issuer,
@ -618,6 +705,9 @@ func TestAuthorizationEndpoint(t *testing.T) {
runOneTestCase := func(t *testing.T, test testCase, subject http.Handler) { runOneTestCase := func(t *testing.T, test testCase, subject http.Handler) {
req := httptest.NewRequest(test.method, test.path, strings.NewReader(test.body)) req := httptest.NewRequest(test.method, test.path, strings.NewReader(test.body))
req.Header.Set("Content-Type", test.contentType) req.Header.Set("Content-Type", test.contentType)
if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie)
}
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
@ -627,7 +717,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
actualLocation := rsp.Header().Get("Location") actualLocation := rsp.Header().Get("Location")
if test.wantLocationHeader != "" { if test.wantLocationHeader != "" {
if test.wantUpstreamStateParamInLocationHeader { if test.wantUpstreamStateParamInLocationHeader {
requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder) requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.stateEncoder)
} }
// The upstream state param is encoded using a timestamp at the beginning so we don't want to // The upstream state param is encoded using a timestamp at the beginning so we don't want to
// compare those states since they may be different, but we do want to compare the downstream // compare those states since they may be different, but we do want to compare the downstream
@ -647,10 +737,17 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantBodyString, rsp.Body.String()) require.Equal(t, test.wantBodyString, rsp.Body.String())
} }
if test.wantCSRFCookieHeader != "" { if test.wantCSRFValueInCookieHeader != "" {
require.Len(t, rsp.Header().Values("Set-Cookie"), 1) require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
actualCookie := rsp.Header().Get("Set-Cookie") actualCookie := rsp.Header().Get("Set-Cookie")
require.Equal(t, actualCookie, test.wantCSRFCookieHeader) regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict")
submatches := regex.FindStringSubmatch(actualCookie)
require.Len(t, submatches, 2)
captured := submatches[1]
var decodedCSRFCookieValue string
err := test.cookieEncoder.Decode("csrf", captured, &decodedCSRFCookieValue)
require.NoError(t, err)
require.Equal(t, test.wantCSRFValueInCookieHeader, decodedCSRFCookieValue)
} else { } else {
require.Empty(t, rsp.Header().Values("Set-Cookie")) require.Empty(t, rsp.Header().Values("Set-Cookie"))
} }
@ -659,16 +756,16 @@ func TestAuthorizationEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
}) })
} }
t.Run("allows upstream provider configuration to change between requests", func(t *testing.T) { t.Run("allows upstream provider configuration to change between requests", func(t *testing.T) {
test := tests[0] test := tests[0]
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case require.Equal(t, "happy path using GET without a CSRF cookie", test.name) // re-use the happy path test case
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
@ -688,7 +785,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"access_type": "offline", "access_type": "offline",
"scope": "other-scope1 other-scope2", "scope": "other-scope1 other-scope2",
"client_id": "some-other-client-id", "client_id": "some-other-client-id",
"state": expectedUpstreamStateParam(nil), "state": expectedUpstreamStateParam(nil, ""),
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",

View File

@ -72,13 +72,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
// the upstream callback endpoint is called later. // the upstream callback endpoint is called later.
oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret
// TODO use different codecs for the state and the cookie, because:
// 1. we would like to state to have an embedded expiration date while the cookie does not need that
// 2. we would like each downstream provider to use different secrets for signing/encrypting the upstream state, not share secrets
// 3. we would like *all* downstream providers to use the *same* signing key for the CSRF cookie (which doesn't need to be encrypted) because cookies are sent per-domain and our issuers can share a domain name (but have different paths)
var encoderHashKey = []byte("fake-hash-secret") // TODO replace this secret var encoderHashKey = []byte("fake-hash-secret") // TODO replace this secret
var encoderBlockKey = []byte("16-bytes-aaaaaaa") // TODO replace this secret var encoderBlockKey = []byte("16-bytes-aaaaaaa") // TODO replace this secret
var encoder = securecookie.New(encoderHashKey, encoderBlockKey) var encoder = securecookie.New(encoderHashKey, encoderBlockKey)
encoder.SetSerializer(securecookie.JSONEncoder{}) encoder.SetSerializer(securecookie.JSONEncoder{})
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder) m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder)
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
} }