Finish the WIP from the previous commit for saving authorize endpoint state

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-11 12:29:14 -08:00 committed by Ryan Richard
parent dd190dede6
commit c2262773e6
5 changed files with 273 additions and 158 deletions

View File

@ -26,13 +26,22 @@ const (
// Just in case we need to make a breaking change to the format of the upstream state param, // Just in case we need to make a breaking change to the format of the upstream state param,
// we are including a format version number. This gives the opportunity for a future version of Pinniped // we are including a format version number. This gives the opportunity for a future version of Pinniped
// to have the consumer of this format decide to reject versions that it doesn't understand. // to have the consumer of this format decide to reject versions that it doesn't understand.
stateParamFormatVersion = "1" upstreamStateParamFormatVersion = "1"
// The `name` passed to the encoder for encoding the upstream state param value. This name is short
// because it will be encoded into the upstream state param value and we're trying to keep that small.
upstreamStateParamEncodingName = "s"
// The name of the browser cookie which shall hold our CSRF value.
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
csrfCookieName = "__Host-pinniped-csrf"
) )
type IDPListGetter interface { type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider GetIDPList() []provider.UpstreamOIDCIdentityProvider
} }
// This is the encoding side of the securecookie.Codec interface.
type Encoder interface { type Encoder interface {
Encode(name string, value interface{}) (string, error) Encode(name string, value interface{}) (string, error)
} }
@ -63,7 +72,7 @@ func NewHandler(
upstreamIDP, err := chooseUpstreamIDP(idpListGetter) upstreamIDP, err := chooseUpstreamIDP(idpListGetter)
if err != nil { if err != nil {
plog.InfoErr("authorize request error", err) plog.WarningErr("authorize upstream config", err)
return err return err
} }
@ -91,7 +100,7 @@ func NewHandler(
csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE) csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE)
if err != nil { if err != nil {
plog.InfoErr("authorize generate error", err) plog.Error("authorize generate error", err)
return err return err
} }
@ -104,24 +113,13 @@ func NewHandler(
Scopes: upstreamIDP.Scopes, Scopes: upstreamIDP.Scopes,
} }
// `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, encoder)
http.SetCookie(w, &http.Cookie{ if err != nil {
Name: "__Host-pinniped-csrf", plog.Error("authorize upstream state param error", err)
Value: string(csrfValue), return err
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: true,
})
stateParamData := upstreamStateParamData{
AuthParams: authorizeRequester.GetRequestForm().Encode(),
Nonce: nonceValue,
CSRFToken: csrfValue,
PKCECode: pkceValue,
StateParamFormatVersion: stateParamFormatVersion,
} }
encodedStateParamValue, err := encoder.Encode("s", stateParamData)
// TODO handle the above error addCSRFSetCookieHeader(w, csrfValue)
http.Redirect(w, r, http.Redirect(w, r,
upstreamOAuthConfig.AuthCodeURL( upstreamOAuthConfig.AuthCodeURL(
@ -138,15 +136,6 @@ func NewHandler(
}) })
} }
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
type upstreamStateParamData struct {
AuthParams string `json:"p"`
Nonce nonce.Nonce `json:"n"`
CSRFToken csrftoken.CSRFToken `json:"c"`
PKCECode pkce.Code `json:"k"`
StateParamFormatVersion string `json:"v"`
}
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
@ -170,22 +159,59 @@ func generateValues(
) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) { ) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) {
csrfValue, err := generateCSRF() csrfValue, err := generateCSRF()
if err != nil { if err != nil {
plog.InfoErr("error generating csrf param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err)
} }
nonceValue, err := generateNonce() nonceValue, err := generateNonce()
if err != nil { if err != nil {
plog.InfoErr("error generating nonce param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating nonce param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating nonce param", err)
} }
pkceValue, err := generatePKCE() pkceValue, err := generatePKCE()
if err != nil { if err != nil {
plog.InfoErr("error generating PKCE param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err)
} }
return csrfValue, nonceValue, pkceValue, nil return csrfValue, nonceValue, pkceValue, nil
} }
// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
type upstreamStateParamData struct {
AuthParams string `json:"p"`
Nonce nonce.Nonce `json:"n"`
CSRFToken csrftoken.CSRFToken `json:"c"`
PKCECode pkce.Code `json:"k"`
StateParamFormatVersion string `json:"v"`
}
func upstreamStateParam(
authorizeRequester fosite.AuthorizeRequester,
nonceValue nonce.Nonce,
csrfValue csrftoken.CSRFToken,
pkceValue pkce.Code,
encoder Encoder,
) (string, error) {
stateParamData := upstreamStateParamData{
AuthParams: authorizeRequester.GetRequestForm().Encode(),
Nonce: nonceValue,
CSRFToken: csrfValue,
PKCECode: pkceValue,
StateParamFormatVersion: upstreamStateParamFormatVersion,
}
encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData)
if err != nil {
return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err)
}
return encodedStateParamValue, nil
}
func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken) {
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: string(csrfValue),
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: true,
})
}
func fositeErrorForLog(err error) []interface{} { func fositeErrorForLog(err error) []interface{} {
rfc6749Error := fosite.ErrorToRFC6749Error(err) rfc6749Error := fosite.ErrorToRFC6749Error(err)
keysAndValues := make([]interface{}, 0) keysAndValues := make([]interface{}, 0)

View File

@ -28,7 +28,8 @@ import (
func TestAuthorizationEndpoint(t *testing.T) { func TestAuthorizationEndpoint(t *testing.T) {
const ( const (
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback"
) )
var ( var (
@ -144,8 +145,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
var encoderHashKey = []byte("fake-hash-secret") var encoderHashKey = []byte("fake-hash-secret")
var encoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption var happyEncoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption
encoder.SetSerializer(securecookie.JSONEncoder{}) happyEncoder.SetSerializer(securecookie.JSONEncoder{})
encodeQuery := func(query map[string]string) string { encodeQuery := func(query map[string]string) string {
values := url.Values{} values := url.Values{}
@ -168,22 +169,24 @@ func TestAuthorizationEndpoint(t *testing.T) {
return urlToReturn return urlToReturn
} }
happyGetRequestQueryMap := map[string]string{ happyGetRequestQueryMap := func(downstreamRedirectURI string) map[string]string {
"response_type": "code", return map[string]string{
"scope": "openid profile email", "response_type": "code",
"client_id": "pinniped-cli", "scope": "openid profile email",
"state": "some-state-value", "client_id": "pinniped-cli",
"nonce": "some-nonce-value", "state": "some-state-value",
"code_challenge": "some-challenge", "nonce": "some-nonce-value",
"code_challenge_method": "S256", "code_challenge": "some-challenge",
"redirect_uri": downstreamRedirectURI, "code_challenge_method": "S256",
"redirect_uri": downstreamRedirectURI,
}
} }
happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap) happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap(downstreamRedirectURI))
modifiedHappyGetRequestPath := func(queryOverrides map[string]string) string { modifiedHappyGetRequestPath := func(queryOverrides map[string]string) string {
copyOfHappyGetRequestQueryMap := map[string]string{} copyOfHappyGetRequestQueryMap := map[string]string{}
for k, v := range happyGetRequestQueryMap { for k, v := range happyGetRequestQueryMap(downstreamRedirectURI) {
copyOfHappyGetRequestQueryMap[k] = v copyOfHappyGetRequestQueryMap[k] = v
} }
for k, v := range queryOverrides { for k, v := range queryOverrides {
@ -197,38 +200,33 @@ func TestAuthorizationEndpoint(t *testing.T) {
return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap) return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap)
} }
// We're going to use this value to make assertions, so specify the exact expected value. happyExpectedUpstreamStateParam := func(downstreamRedirectURI string) string {
happyUpstreamStateParam, err := encoder.Encode("s", encoded, err := happyEncoder.Encode("s",
// Ensure that the order of the serialized fields is exactly this order, so we can make simpler equality assertions below. expectedUpstreamStateParamFormat{
struct { P: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)),
P string `json:"p"` N: happyNonce,
N string `json:"n"` C: happyCSRF,
C string `json:"c"` K: happyPKCE,
K string `json:"k"` V: "1",
V string `json:"v"` },
}{ )
P: encodeQuery(happyGetRequestQueryMap), require.NoError(t, err)
N: happyNonce, return encoded
C: happyCSRF, }
K: happyPKCE,
V: "1",
},
)
require.NoError(t, err)
happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(), happyExpectedRedirectLocation := func(downstreamRedirectURI string) string {
map[string]string{ return urlWithQuery(upstreamAuthURL.String(), map[string]string{
"response_type": "code", "response_type": "code",
"access_type": "offline", "access_type": "offline",
"scope": "scope1 scope2", "scope": "scope1 scope2",
"client_id": "some-client-id", "client_id": "some-client-id",
"state": happyUpstreamStateParam, "state": happyExpectedUpstreamStateParam(downstreamRedirectURI),
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-idp", "redirect_uri": issuer + "/callback/some-idp",
}, })
) }
happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF) happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF)
@ -240,6 +238,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF func() (csrftoken.CSRFToken, error) generateCSRF func() (csrftoken.CSRFToken, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
encoder securecookie.Codec
method string method string
path string path string
contentType string contentType string
@ -251,8 +250,9 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyJSON string wantBodyJSON string
wantLocationHeader string wantLocationHeader string
wantCSRFCookieHeader string wantCSRFCookieHeader string
}
wantUpstreamStateParamInLocationHeader bool
}
tests := []testCase{ tests := []testCase{
{ {
name: "happy path using GET", name: "happy path using GET",
@ -261,54 +261,59 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8", wantContentType: "text/html; charset=utf-8",
wantBodyString: fmt.Sprintf(`<a href="%s">Found</a>.%s`, wantBodyString: fmt.Sprintf(`<a href="%s">Found</a>.%s`,
html.EscapeString(happyGetRequestExpectedRedirectLocation), html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURI)),
"\n\n", "\n\n",
), ),
wantLocationHeader: happyGetRequestExpectedRedirectLocation, wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI),
wantUpstreamStateParamInLocationHeader: true,
}, },
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodPost,
path: "/some/path",
contentType: "application/x-www-form-urlencoded",
body: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)),
wantStatus: http.StatusFound,
wantContentType: "",
wantBodyString: "",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI),
wantUpstreamStateParamInLocationHeader: true,
},
{
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodPost, encoder: happyEncoder,
path: "/some/path", method: http.MethodGet,
contentType: "application/x-www-form-urlencoded", path: modifiedHappyGetRequestPath(map[string]string{
body: url.Values{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
"response_type": []string{"code"}, }),
"scope": []string{"openid profile email"}, wantStatus: http.StatusFound,
"client_id": []string{"pinniped-cli"}, wantContentType: "text/html; charset=utf-8",
"state": []string{"some-state-value"}, wantBodyString: fmt.Sprintf(`<a href="%s">Found</a>.%s`,
"code_challenge": []string{"some-challenge"}, html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort)),
"code_challenge_method": []string{"S256"}, "\n\n",
"redirect_uri": []string{downstreamRedirectURI}, ),
}.Encode(), wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantStatus: http.StatusFound, wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort),
wantContentType: "", wantUpstreamStateParamInLocationHeader: true,
wantBodyString: "",
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
},
{
name: "downstream client does not exist",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}),
wantStatus: http.StatusUnauthorized,
wantContentType: "application/json; charset=utf-8",
wantBodyJSON: fositeInvalidClientErrorBody,
}, },
{ {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
@ -317,6 +322,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{ path: modifiedHappyGetRequestPath(map[string]string{
"redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client", "redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client",
@ -326,24 +332,18 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyJSON: fositeInvalidRedirectURIErrorBody, wantBodyJSON: fositeInvalidRedirectURIErrorBody,
}, },
{ {
name: "happy path when downstream redirect uri matches what is configured for client except for the port number", name: "downstream client does not exist",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
method: http.MethodGet, encoder: happyEncoder,
path: modifiedHappyGetRequestPath(map[string]string{ method: http.MethodGet,
"redirect_uri": "http://127.0.0.1:42/callback", path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}),
}), wantStatus: http.StatusUnauthorized,
wantStatus: http.StatusFound, wantContentType: "application/json; charset=utf-8",
wantContentType: "text/html; charset=utf-8", wantBodyJSON: fositeInvalidClientErrorBody,
wantBodyString: fmt.Sprintf(`<a href="%s">Found</a>.%s`,
html.EscapeString(happyGetRequestExpectedRedirectLocation),
"\n\n",
),
wantLocationHeader: happyGetRequestExpectedRedirectLocation,
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
}, },
{ {
name: "response type is unsupported", name: "response type is unsupported",
@ -352,6 +352,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}), path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -366,6 +367,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}), path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -380,6 +382,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}), path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -394,6 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}), path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
@ -407,6 +411,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -421,6 +426,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -435,6 +441,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -449,6 +456,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}), path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -465,6 +473,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}), path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -479,6 +488,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}), path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}),
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
@ -486,6 +496,20 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantLocationHeader: urlWithQuery(downstreamRedirectURI, fositeInvalidStateErrorQuery), wantLocationHeader: urlWithQuery(downstreamRedirectURI, fositeInvalidStateErrorQuery),
wantBodyString: "", wantBodyString: "",
}, },
{
name: "error while encoding upstream state param",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
encoder: &errorReturningEncoder{},
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError,
wantContentType: "text/plain; charset=utf-8",
wantBodyString: "Internal Server Error: error encoding upstream state param\n",
},
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: issuer,
@ -493,6 +517,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
@ -506,6 +531,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
@ -519,6 +545,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
encoder: happyEncoder,
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError, wantStatus: http.StatusInternalServerError,
@ -586,19 +613,22 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
if test.wantLocationHeader != "" {
actualLocation := rsp.Header().Get("Location")
if test.wantUpstreamStateParamInLocationHeader {
requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder)
}
requireEqualURLs(t, actualLocation, test.wantLocationHeader)
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
if test.wantBodyJSON != "" { if test.wantBodyJSON != "" {
require.JSONEq(t, test.wantBodyJSON, rsp.Body.String()) require.JSONEq(t, test.wantBodyJSON, rsp.Body.String())
} else { } else {
require.Equal(t, test.wantBodyString, rsp.Body.String()) require.Equal(t, test.wantBodyString, rsp.Body.String())
} }
if test.wantLocationHeader != "" {
actualLocation := rsp.Header().Get("Location")
requireEqualURLs(t, actualLocation, test.wantLocationHeader)
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
if test.wantCSRFCookieHeader != "" { if test.wantCSRFCookieHeader != "" {
require.Len(t, rsp.Header().Values("Set-Cookie"), 1) require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
actualCookie := rsp.Header().Get("Set-Cookie") actualCookie := rsp.Header().Get("Set-Cookie")
@ -611,7 +641,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
}) })
} }
@ -620,7 +650,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
test := tests[0] test := tests[0]
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case
subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder) subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder)
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
@ -640,7 +670,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"access_type": "offline", "access_type": "offline",
"scope": "other-scope1 other-scope2", "scope": "other-scope1 other-scope2",
"client_id": "some-other-client-id", "client_id": "some-other-client-id",
"state": happyUpstreamStateParam, "state": happyExpectedUpstreamStateParam(downstreamRedirectURI),
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
@ -660,6 +690,26 @@ func TestAuthorizationEndpoint(t *testing.T) {
}) })
} }
// Declare a separate type from the production code to ensure that the state param's contents was serialized
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
// assertions about the redirect URL in this test.
type expectedUpstreamStateParamFormat struct {
P string `json:"p"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}
type errorReturningEncoder struct {
securecookie.Codec
}
func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) {
return "", fmt.Errorf("some encoding error")
}
func requireEqualContentType(t *testing.T, actual string, expected string) { func requireEqualContentType(t *testing.T, actual string, expected string) {
t.Helper() t.Helper()
@ -676,6 +726,28 @@ func requireEqualContentType(t *testing.T, actual string, expected string) {
require.Equal(t, actualContentTypeParams, expectedContentTypeParams) require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
} }
func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) {
t.Helper()
actualLocationURL, err := url.Parse(actualURL)
require.NoError(t, err)
expectedLocationURL, err := url.Parse(expectedURL)
require.NoError(t, err)
expectedQueryStateParam := expectedLocationURL.Query().Get("state")
require.NotEmpty(t, expectedQueryStateParam)
var expectedDecodedStateParam expectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
require.NoError(t, err)
actualQueryStateParam := actualLocationURL.Query().Get("state")
require.NotEmpty(t, actualQueryStateParam)
var actualDecodedStateParam expectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
require.NoError(t, err)
require.Equal(t, expectedDecodedStateParam, actualDecodedStateParam)
}
func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) { func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) {
t.Helper() t.Helper()
actualLocationURL, err := url.Parse(actualURL) actualLocationURL, err := url.Parse(actualURL)

View File

@ -9,13 +9,12 @@ import (
"sync" "sync"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"go.pinniped.dev/internal/oidc/csrftoken"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/storage" "github.com/ory/fosite/storage"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/auth" "go.pinniped.dev/internal/oidc/auth"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"

View File

@ -1,22 +0,0 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package plog implements a thin layer over klog to help enforce pinniped's logging convention.
// Logs are always structured as a constant message with key and value pairs of related metadata.
// The logging levels in order of increasing verbosity are:
// error, warning, info, debug, trace and all.
// error and warning logs are always emitted (there is no way for the end user to disable them),
// and thus should be used sparingly. Ideally, logs at these levels should be actionable.
// info should be reserved for "nice to know" information. It should be possible to run a production
// pinniped server at the info log level with no performance degradation due to high log volume.
// debug should be used for information targeted at developers and to aid in support cases. Care must
// be taken at this level to not leak any secrets into the log stream. That is, even though debug may
// cause performance issues in production, it must not cause security issues in production.
// trace should be used to log information related to timing (i.e. the time it took a controller to sync).
// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information
// about the current state of the process, but that, along with performance degradation, is expected.
// all is reserved for the most verbose and security sensitive information. At this level, full request
// metadata such as headers and parameters along with the body may be logged. This level is completely
// unfit for production use both from a performance and security standpoint. Using it is generally an
// act of desperation to determine why the system is broken.
package plog

View File

@ -1,12 +1,37 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package plog implements a thin layer over klog to help enforce pinniped's logging convention.
// Logs are always structured as a constant message with key and value pairs of related metadata.
//
// The logging levels in order of increasing verbosity are:
// error, warning, info, debug, trace and all.
//
// error and warning logs are always emitted (there is no way for the end user to disable them),
// and thus should be used sparingly. Ideally, logs at these levels should be actionable.
//
// info should be reserved for "nice to know" information. It should be possible to run a production
// pinniped server at the info log level with no performance degradation due to high log volume.
// debug should be used for information targeted at developers and to aid in support cases. Care must
// be taken at this level to not leak any secrets into the log stream. That is, even though debug may
// cause performance issues in production, it must not cause security issues in production.
//
// trace should be used to log information related to timing (i.e. the time it took a controller to sync).
// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information
// about the current state of the process, but that, along with performance degradation, is expected.
//
// all is reserved for the most verbose and security sensitive information. At this level, full request
// metadata such as headers and parameters along with the body may be logged. This level is completely
// unfit for production use both from a performance and security standpoint. Using it is generally an
// act of desperation to determine why the system is broken.
package plog package plog
import "k8s.io/klog/v2" import "k8s.io/klog/v2"
const errorKey = "error"
// Use Error to log an unexpected system error. // Use Error to log an unexpected system error.
func Error(err error, msg string, keysAndValues ...interface{}) { func Error(msg string, err error, keysAndValues ...interface{}) {
klog.ErrorS(err, msg, keysAndValues...) klog.ErrorS(err, msg, keysAndValues...)
} }
@ -19,23 +44,38 @@ func Warning(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelWarning).InfoS(msg, keysAndValues...) klog.V(klogLevelWarning).InfoS(msg, keysAndValues...)
} }
// Use WarningErr to issue a Warning message with an error object as part of the message.
func WarningErr(msg string, err error, keysAndValues ...interface{}) {
Warning(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
}
func Info(msg string, keysAndValues ...interface{}) { func Info(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelInfo).InfoS(msg, keysAndValues...) klog.V(klogLevelInfo).InfoS(msg, keysAndValues...)
} }
// Use InfoErr to log an expected error, e.g. validation failure of an http parameter. // Use InfoErr to log an expected error, e.g. validation failure of an http parameter.
func InfoErr(msg string, err error, keysAndValues ...interface{}) { func InfoErr(msg string, err error, keysAndValues ...interface{}) {
klog.V(klogLevelInfo).InfoS(msg, append([]interface{}{"error", err}, keysAndValues)...) Info(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
} }
func Debug(msg string, keysAndValues ...interface{}) { func Debug(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelDebug).InfoS(msg, keysAndValues...) klog.V(klogLevelDebug).InfoS(msg, keysAndValues...)
} }
// Use DebugErr to issue a Debug message with an error object as part of the message.
func DebugErr(msg string, err error, keysAndValues ...interface{}) {
Debug(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
}
func Trace(msg string, keysAndValues ...interface{}) { func Trace(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelTrace).InfoS(msg, keysAndValues...) klog.V(klogLevelTrace).InfoS(msg, keysAndValues...)
} }
// Use TraceErr to issue a Trace message with an error object as part of the message.
func TraceErr(msg string, err error, keysAndValues ...interface{}) {
Trace(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
}
func All(msg string, keysAndValues ...interface{}) { func All(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelAll).InfoS(msg, keysAndValues...) klog.V(klogLevelAll).InfoS(msg, keysAndValues...)
} }