diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go
index f562a393..b47d4984 100644
--- a/internal/oidc/auth/auth_handler.go
+++ b/internal/oidc/auth/auth_handler.go
@@ -26,13 +26,22 @@ const (
// Just in case we need to make a breaking change to the format of the upstream state param,
// we are including a format version number. This gives the opportunity for a future version of Pinniped
// to have the consumer of this format decide to reject versions that it doesn't understand.
- stateParamFormatVersion = "1"
+ upstreamStateParamFormatVersion = "1"
+
+ // The `name` passed to the encoder for encoding the upstream state param value. This name is short
+ // because it will be encoded into the upstream state param value and we're trying to keep that small.
+ upstreamStateParamEncodingName = "s"
+
+ // The name of the browser cookie which shall hold our CSRF value.
+ // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
+ csrfCookieName = "__Host-pinniped-csrf"
)
type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider
}
+// This is the encoding side of the securecookie.Codec interface.
type Encoder interface {
Encode(name string, value interface{}) (string, error)
}
@@ -63,7 +72,7 @@ func NewHandler(
upstreamIDP, err := chooseUpstreamIDP(idpListGetter)
if err != nil {
- plog.InfoErr("authorize request error", err)
+ plog.WarningErr("authorize upstream config", err)
return err
}
@@ -91,7 +100,7 @@ func NewHandler(
csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE)
if err != nil {
- plog.InfoErr("authorize generate error", err)
+ plog.Error("authorize generate error", err)
return err
}
@@ -104,24 +113,13 @@ func NewHandler(
Scopes: upstreamIDP.Scopes,
}
- // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes
- http.SetCookie(w, &http.Cookie{
- Name: "__Host-pinniped-csrf",
- Value: string(csrfValue),
- HttpOnly: true,
- SameSite: http.SameSiteStrictMode,
- Secure: true,
- })
-
- stateParamData := upstreamStateParamData{
- AuthParams: authorizeRequester.GetRequestForm().Encode(),
- Nonce: nonceValue,
- CSRFToken: csrfValue,
- PKCECode: pkceValue,
- StateParamFormatVersion: stateParamFormatVersion,
+ encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, encoder)
+ if err != nil {
+ plog.Error("authorize upstream state param error", err)
+ return err
}
- encodedStateParamValue, err := encoder.Encode("s", stateParamData)
- // TODO handle the above error
+
+ addCSRFSetCookieHeader(w, csrfValue)
http.Redirect(w, r,
upstreamOAuthConfig.AuthCodeURL(
@@ -138,15 +136,6 @@ func NewHandler(
})
}
-// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
-type upstreamStateParamData struct {
- AuthParams string `json:"p"`
- Nonce nonce.Nonce `json:"n"`
- CSRFToken csrftoken.CSRFToken `json:"c"`
- PKCECode pkce.Code `json:"k"`
- StateParamFormatVersion string `json:"v"`
-}
-
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 {
@@ -170,22 +159,59 @@ func generateValues(
) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) {
csrfValue, err := generateCSRF()
if err != nil {
- plog.InfoErr("error generating csrf param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err)
}
nonceValue, err := generateNonce()
if err != nil {
- plog.InfoErr("error generating nonce param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating nonce param", err)
}
pkceValue, err := generatePKCE()
if err != nil {
- plog.InfoErr("error generating PKCE param", err)
return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err)
}
return csrfValue, nonceValue, pkceValue, nil
}
+// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param.
+type upstreamStateParamData struct {
+ AuthParams string `json:"p"`
+ Nonce nonce.Nonce `json:"n"`
+ CSRFToken csrftoken.CSRFToken `json:"c"`
+ PKCECode pkce.Code `json:"k"`
+ StateParamFormatVersion string `json:"v"`
+}
+
+func upstreamStateParam(
+ authorizeRequester fosite.AuthorizeRequester,
+ nonceValue nonce.Nonce,
+ csrfValue csrftoken.CSRFToken,
+ pkceValue pkce.Code,
+ encoder Encoder,
+) (string, error) {
+ stateParamData := upstreamStateParamData{
+ AuthParams: authorizeRequester.GetRequestForm().Encode(),
+ Nonce: nonceValue,
+ CSRFToken: csrfValue,
+ PKCECode: pkceValue,
+ StateParamFormatVersion: upstreamStateParamFormatVersion,
+ }
+ encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData)
+ if err != nil {
+ return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err)
+ }
+ return encodedStateParamValue, nil
+}
+
+func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken) {
+ http.SetCookie(w, &http.Cookie{
+ Name: csrfCookieName,
+ Value: string(csrfValue),
+ HttpOnly: true,
+ SameSite: http.SameSiteStrictMode,
+ Secure: true,
+ })
+}
+
func fositeErrorForLog(err error) []interface{} {
rfc6749Error := fosite.ErrorToRFC6749Error(err)
keysAndValues := make([]interface{}, 0)
diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go
index cc708f1a..46557937 100644
--- a/internal/oidc/auth/auth_handler_test.go
+++ b/internal/oidc/auth/auth_handler_test.go
@@ -28,7 +28,8 @@ import (
func TestAuthorizationEndpoint(t *testing.T) {
const (
- downstreamRedirectURI = "http://127.0.0.1/callback"
+ downstreamRedirectURI = "http://127.0.0.1/callback"
+ downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback"
)
var (
@@ -144,8 +145,8 @@ func TestAuthorizationEndpoint(t *testing.T) {
expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"
var encoderHashKey = []byte("fake-hash-secret")
- var encoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption
- encoder.SetSerializer(securecookie.JSONEncoder{})
+ var happyEncoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption
+ happyEncoder.SetSerializer(securecookie.JSONEncoder{})
encodeQuery := func(query map[string]string) string {
values := url.Values{}
@@ -168,22 +169,24 @@ func TestAuthorizationEndpoint(t *testing.T) {
return urlToReturn
}
- happyGetRequestQueryMap := map[string]string{
- "response_type": "code",
- "scope": "openid profile email",
- "client_id": "pinniped-cli",
- "state": "some-state-value",
- "nonce": "some-nonce-value",
- "code_challenge": "some-challenge",
- "code_challenge_method": "S256",
- "redirect_uri": downstreamRedirectURI,
+ happyGetRequestQueryMap := func(downstreamRedirectURI string) map[string]string {
+ return map[string]string{
+ "response_type": "code",
+ "scope": "openid profile email",
+ "client_id": "pinniped-cli",
+ "state": "some-state-value",
+ "nonce": "some-nonce-value",
+ "code_challenge": "some-challenge",
+ "code_challenge_method": "S256",
+ "redirect_uri": downstreamRedirectURI,
+ }
}
- happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap)
+ happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap(downstreamRedirectURI))
modifiedHappyGetRequestPath := func(queryOverrides map[string]string) string {
copyOfHappyGetRequestQueryMap := map[string]string{}
- for k, v := range happyGetRequestQueryMap {
+ for k, v := range happyGetRequestQueryMap(downstreamRedirectURI) {
copyOfHappyGetRequestQueryMap[k] = v
}
for k, v := range queryOverrides {
@@ -197,38 +200,33 @@ func TestAuthorizationEndpoint(t *testing.T) {
return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap)
}
- // We're going to use this value to make assertions, so specify the exact expected value.
- happyUpstreamStateParam, err := encoder.Encode("s",
- // Ensure that the order of the serialized fields is exactly this order, so we can make simpler equality assertions below.
- struct {
- P string `json:"p"`
- N string `json:"n"`
- C string `json:"c"`
- K string `json:"k"`
- V string `json:"v"`
- }{
- P: encodeQuery(happyGetRequestQueryMap),
- N: happyNonce,
- C: happyCSRF,
- K: happyPKCE,
- V: "1",
- },
- )
- require.NoError(t, err)
+ happyExpectedUpstreamStateParam := func(downstreamRedirectURI string) string {
+ encoded, err := happyEncoder.Encode("s",
+ expectedUpstreamStateParamFormat{
+ P: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)),
+ N: happyNonce,
+ C: happyCSRF,
+ K: happyPKCE,
+ V: "1",
+ },
+ )
+ require.NoError(t, err)
+ return encoded
+ }
- happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(),
- map[string]string{
+ happyExpectedRedirectLocation := func(downstreamRedirectURI string) string {
+ return urlWithQuery(upstreamAuthURL.String(), map[string]string{
"response_type": "code",
"access_type": "offline",
"scope": "scope1 scope2",
"client_id": "some-client-id",
- "state": happyUpstreamStateParam,
+ "state": happyExpectedUpstreamStateParam(downstreamRedirectURI),
"nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-idp",
- },
- )
+ })
+ }
happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF)
@@ -240,6 +238,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF func() (csrftoken.CSRFToken, error)
generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error)
+ encoder securecookie.Codec
method string
path string
contentType string
@@ -251,8 +250,9 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyJSON string
wantLocationHeader string
wantCSRFCookieHeader string
- }
+ wantUpstreamStateParamInLocationHeader bool
+ }
tests := []testCase{
{
name: "happy path using GET",
@@ -261,54 +261,59 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusFound,
wantContentType: "text/html; charset=utf-8",
wantBodyString: fmt.Sprintf(`Found.%s`,
- html.EscapeString(happyGetRequestExpectedRedirectLocation),
+ html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURI)),
"\n\n",
),
- wantLocationHeader: happyGetRequestExpectedRedirectLocation,
- wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
+ wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
+ wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI),
+ wantUpstreamStateParamInLocationHeader: true,
},
{
- name: "happy path using POST",
+ name: "happy path using POST",
+ issuer: issuer,
+ idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
+ generateCSRF: happyCSRFGenerator,
+ generatePKCE: happyPKCEGenerator,
+ generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
+ method: http.MethodPost,
+ path: "/some/path",
+ contentType: "application/x-www-form-urlencoded",
+ body: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)),
+ wantStatus: http.StatusFound,
+ wantContentType: "",
+ wantBodyString: "",
+ wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
+ wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI),
+ wantUpstreamStateParamInLocationHeader: true,
+ },
+ {
+ name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
- method: http.MethodPost,
- path: "/some/path",
- contentType: "application/x-www-form-urlencoded",
- body: url.Values{
- "response_type": []string{"code"},
- "scope": []string{"openid profile email"},
- "client_id": []string{"pinniped-cli"},
- "state": []string{"some-state-value"},
- "code_challenge": []string{"some-challenge"},
- "code_challenge_method": []string{"S256"},
- "redirect_uri": []string{downstreamRedirectURI},
- }.Encode(),
- wantStatus: http.StatusFound,
- wantContentType: "",
- wantBodyString: "",
- wantLocationHeader: happyGetRequestExpectedRedirectLocation,
- wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
- },
- {
- name: "downstream client does not exist",
- issuer: issuer,
- idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
- generateCSRF: happyCSRFGenerator,
- generatePKCE: happyPKCEGenerator,
- generateNonce: happyNonceGenerator,
- method: http.MethodGet,
- path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}),
- wantStatus: http.StatusUnauthorized,
- wantContentType: "application/json; charset=utf-8",
- wantBodyJSON: fositeInvalidClientErrorBody,
+ encoder: happyEncoder,
+ method: http.MethodGet,
+ path: modifiedHappyGetRequestPath(map[string]string{
+ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
+ }),
+ wantStatus: http.StatusFound,
+ wantContentType: "text/html; charset=utf-8",
+ wantBodyString: fmt.Sprintf(`Found.%s`,
+ html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort)),
+ "\n\n",
+ ),
+ wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
+ wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort),
+ wantUpstreamStateParamInLocationHeader: true,
},
{
name: "downstream redirect uri does not match what is configured for client",
@@ -317,6 +322,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{
"redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client",
@@ -326,24 +332,18 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantBodyJSON: fositeInvalidRedirectURIErrorBody,
},
{
- name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
- issuer: issuer,
- idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
- generateCSRF: happyCSRFGenerator,
- generatePKCE: happyPKCEGenerator,
- generateNonce: happyNonceGenerator,
- method: http.MethodGet,
- path: modifiedHappyGetRequestPath(map[string]string{
- "redirect_uri": "http://127.0.0.1:42/callback",
- }),
- wantStatus: http.StatusFound,
- wantContentType: "text/html; charset=utf-8",
- wantBodyString: fmt.Sprintf(`Found.%s`,
- html.EscapeString(happyGetRequestExpectedRedirectLocation),
- "\n\n",
- ),
- wantLocationHeader: happyGetRequestExpectedRedirectLocation,
- wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
+ name: "downstream client does not exist",
+ issuer: issuer,
+ idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
+ generateCSRF: happyCSRFGenerator,
+ generatePKCE: happyPKCEGenerator,
+ generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
+ method: http.MethodGet,
+ path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}),
+ wantStatus: http.StatusUnauthorized,
+ wantContentType: "application/json; charset=utf-8",
+ wantBodyJSON: fositeInvalidClientErrorBody,
},
{
name: "response type is unsupported",
@@ -352,6 +352,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}),
wantStatus: http.StatusFound,
@@ -366,6 +367,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}),
wantStatus: http.StatusFound,
@@ -380,6 +382,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}),
wantStatus: http.StatusFound,
@@ -394,6 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}),
wantStatus: http.StatusUnauthorized,
@@ -407,6 +411,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}),
wantStatus: http.StatusFound,
@@ -421,6 +426,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}),
wantStatus: http.StatusFound,
@@ -435,6 +441,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}),
wantStatus: http.StatusFound,
@@ -449,6 +456,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}),
wantStatus: http.StatusFound,
@@ -465,6 +473,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}),
wantStatus: http.StatusFound,
@@ -479,6 +488,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}),
wantStatus: http.StatusFound,
@@ -486,6 +496,20 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantLocationHeader: urlWithQuery(downstreamRedirectURI, fositeInvalidStateErrorQuery),
wantBodyString: "",
},
+ {
+ name: "error while encoding upstream state param",
+ issuer: issuer,
+ idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
+ generateCSRF: happyCSRFGenerator,
+ generatePKCE: happyPKCEGenerator,
+ generateNonce: happyNonceGenerator,
+ encoder: &errorReturningEncoder{},
+ method: http.MethodGet,
+ path: happyGetRequestPath,
+ wantStatus: http.StatusInternalServerError,
+ wantContentType: "text/plain; charset=utf-8",
+ wantBodyString: "Internal Server Error: error encoding upstream state param\n",
+ },
{
name: "error while generating CSRF token",
issuer: issuer,
@@ -493,6 +517,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError,
@@ -506,6 +531,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
+ encoder: happyEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError,
@@ -519,6 +545,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusInternalServerError,
@@ -586,19 +613,22 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
+ if test.wantLocationHeader != "" {
+ actualLocation := rsp.Header().Get("Location")
+ if test.wantUpstreamStateParamInLocationHeader {
+ requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder)
+ }
+ requireEqualURLs(t, actualLocation, test.wantLocationHeader)
+ } else {
+ require.Empty(t, rsp.Header().Values("Location"))
+ }
+
if test.wantBodyJSON != "" {
require.JSONEq(t, test.wantBodyJSON, rsp.Body.String())
} else {
require.Equal(t, test.wantBodyString, rsp.Body.String())
}
- if test.wantLocationHeader != "" {
- actualLocation := rsp.Header().Get("Location")
- requireEqualURLs(t, actualLocation, test.wantLocationHeader)
- } else {
- require.Empty(t, rsp.Header().Values("Location"))
- }
-
if test.wantCSRFCookieHeader != "" {
require.Len(t, rsp.Header().Values("Set-Cookie"), 1)
actualCookie := rsp.Header().Get("Set-Cookie")
@@ -611,7 +641,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
- subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
+ subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder)
runOneTestCase(t, test, subject)
})
}
@@ -620,7 +650,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
test := tests[0]
require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case
- subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder)
+ subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder)
runOneTestCase(t, test, subject)
@@ -640,7 +670,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"access_type": "offline",
"scope": "other-scope1 other-scope2",
"client_id": "some-other-client-id",
- "state": happyUpstreamStateParam,
+ "state": happyExpectedUpstreamStateParam(downstreamRedirectURI),
"nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256",
@@ -660,6 +690,26 @@ func TestAuthorizationEndpoint(t *testing.T) {
})
}
+// Declare a separate type from the production code to ensure that the state param's contents was serialized
+// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
+// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
+// assertions about the redirect URL in this test.
+type expectedUpstreamStateParamFormat struct {
+ P string `json:"p"`
+ N string `json:"n"`
+ C string `json:"c"`
+ K string `json:"k"`
+ V string `json:"v"`
+}
+
+type errorReturningEncoder struct {
+ securecookie.Codec
+}
+
+func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) {
+ return "", fmt.Errorf("some encoding error")
+}
+
func requireEqualContentType(t *testing.T, actual string, expected string) {
t.Helper()
@@ -676,6 +726,28 @@ func requireEqualContentType(t *testing.T, actual string, expected string) {
require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
}
+func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) {
+ t.Helper()
+ actualLocationURL, err := url.Parse(actualURL)
+ require.NoError(t, err)
+ expectedLocationURL, err := url.Parse(expectedURL)
+ require.NoError(t, err)
+
+ expectedQueryStateParam := expectedLocationURL.Query().Get("state")
+ require.NotEmpty(t, expectedQueryStateParam)
+ var expectedDecodedStateParam expectedUpstreamStateParamFormat
+ err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
+ require.NoError(t, err)
+
+ actualQueryStateParam := actualLocationURL.Query().Get("state")
+ require.NotEmpty(t, actualQueryStateParam)
+ var actualDecodedStateParam expectedUpstreamStateParamFormat
+ err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
+ require.NoError(t, err)
+
+ require.Equal(t, expectedDecodedStateParam, actualDecodedStateParam)
+}
+
func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) {
t.Helper()
actualLocationURL, err := url.Parse(actualURL)
diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go
index 77112fca..1cb37fc8 100644
--- a/internal/oidc/provider/manager/manager.go
+++ b/internal/oidc/provider/manager/manager.go
@@ -9,13 +9,12 @@ import (
"sync"
"github.com/gorilla/securecookie"
- "go.pinniped.dev/internal/oidc/csrftoken"
-
"github.com/ory/fosite"
"github.com/ory/fosite/storage"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/auth"
+ "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/provider"
diff --git a/internal/plog/doc.go b/internal/plog/doc.go
deleted file mode 100644
index 41cc3071..00000000
--- a/internal/plog/doc.go
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2020 the Pinniped contributors. All Rights Reserved.
-// SPDX-License-Identifier: Apache-2.0
-
-// Package plog implements a thin layer over klog to help enforce pinniped's logging convention.
-// Logs are always structured as a constant message with key and value pairs of related metadata.
-// The logging levels in order of increasing verbosity are:
-// error, warning, info, debug, trace and all.
-// error and warning logs are always emitted (there is no way for the end user to disable them),
-// and thus should be used sparingly. Ideally, logs at these levels should be actionable.
-// info should be reserved for "nice to know" information. It should be possible to run a production
-// pinniped server at the info log level with no performance degradation due to high log volume.
-// debug should be used for information targeted at developers and to aid in support cases. Care must
-// be taken at this level to not leak any secrets into the log stream. That is, even though debug may
-// cause performance issues in production, it must not cause security issues in production.
-// trace should be used to log information related to timing (i.e. the time it took a controller to sync).
-// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information
-// about the current state of the process, but that, along with performance degradation, is expected.
-// all is reserved for the most verbose and security sensitive information. At this level, full request
-// metadata such as headers and parameters along with the body may be logged. This level is completely
-// unfit for production use both from a performance and security standpoint. Using it is generally an
-// act of desperation to determine why the system is broken.
-package plog
diff --git a/internal/plog/plog.go b/internal/plog/plog.go
index d3b6efa8..ffcefdb8 100644
--- a/internal/plog/plog.go
+++ b/internal/plog/plog.go
@@ -1,12 +1,37 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
+// Package plog implements a thin layer over klog to help enforce pinniped's logging convention.
+// Logs are always structured as a constant message with key and value pairs of related metadata.
+//
+// The logging levels in order of increasing verbosity are:
+// error, warning, info, debug, trace and all.
+//
+// error and warning logs are always emitted (there is no way for the end user to disable them),
+// and thus should be used sparingly. Ideally, logs at these levels should be actionable.
+//
+// info should be reserved for "nice to know" information. It should be possible to run a production
+// pinniped server at the info log level with no performance degradation due to high log volume.
+// debug should be used for information targeted at developers and to aid in support cases. Care must
+// be taken at this level to not leak any secrets into the log stream. That is, even though debug may
+// cause performance issues in production, it must not cause security issues in production.
+//
+// trace should be used to log information related to timing (i.e. the time it took a controller to sync).
+// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information
+// about the current state of the process, but that, along with performance degradation, is expected.
+//
+// all is reserved for the most verbose and security sensitive information. At this level, full request
+// metadata such as headers and parameters along with the body may be logged. This level is completely
+// unfit for production use both from a performance and security standpoint. Using it is generally an
+// act of desperation to determine why the system is broken.
package plog
import "k8s.io/klog/v2"
+const errorKey = "error"
+
// Use Error to log an unexpected system error.
-func Error(err error, msg string, keysAndValues ...interface{}) {
+func Error(msg string, err error, keysAndValues ...interface{}) {
klog.ErrorS(err, msg, keysAndValues...)
}
@@ -19,23 +44,38 @@ func Warning(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelWarning).InfoS(msg, keysAndValues...)
}
+// Use WarningErr to issue a Warning message with an error object as part of the message.
+func WarningErr(msg string, err error, keysAndValues ...interface{}) {
+ Warning(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
+}
+
func Info(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelInfo).InfoS(msg, keysAndValues...)
}
// Use InfoErr to log an expected error, e.g. validation failure of an http parameter.
func InfoErr(msg string, err error, keysAndValues ...interface{}) {
- klog.V(klogLevelInfo).InfoS(msg, append([]interface{}{"error", err}, keysAndValues)...)
+ Info(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
}
func Debug(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelDebug).InfoS(msg, keysAndValues...)
}
+// Use DebugErr to issue a Debug message with an error object as part of the message.
+func DebugErr(msg string, err error, keysAndValues ...interface{}) {
+ Debug(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
+}
+
func Trace(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelTrace).InfoS(msg, keysAndValues...)
}
+// Use TraceErr to issue a Trace message with an error object as part of the message.
+func TraceErr(msg string, err error, keysAndValues ...interface{}) {
+ Trace(msg, append([]interface{}{errorKey, err}, keysAndValues)...)
+}
+
func All(msg string, keysAndValues ...interface{}) {
klog.V(klogLevelAll).InfoS(msg, keysAndValues...)
}