diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index f562a393..b47d4984 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -26,13 +26,22 @@ const ( // Just in case we need to make a breaking change to the format of the upstream state param, // we are including a format version number. This gives the opportunity for a future version of Pinniped // to have the consumer of this format decide to reject versions that it doesn't understand. - stateParamFormatVersion = "1" + upstreamStateParamFormatVersion = "1" + + // The `name` passed to the encoder for encoding the upstream state param value. This name is short + // because it will be encoded into the upstream state param value and we're trying to keep that small. + upstreamStateParamEncodingName = "s" + + // The name of the browser cookie which shall hold our CSRF value. + // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes + csrfCookieName = "__Host-pinniped-csrf" ) type IDPListGetter interface { GetIDPList() []provider.UpstreamOIDCIdentityProvider } +// This is the encoding side of the securecookie.Codec interface. type Encoder interface { Encode(name string, value interface{}) (string, error) } @@ -63,7 +72,7 @@ func NewHandler( upstreamIDP, err := chooseUpstreamIDP(idpListGetter) if err != nil { - plog.InfoErr("authorize request error", err) + plog.WarningErr("authorize upstream config", err) return err } @@ -91,7 +100,7 @@ func NewHandler( csrfValue, nonceValue, pkceValue, err := generateValues(generateCSRF, generateNonce, generatePKCE) if err != nil { - plog.InfoErr("authorize generate error", err) + plog.Error("authorize generate error", err) return err } @@ -104,24 +113,13 @@ func NewHandler( Scopes: upstreamIDP.Scopes, } - // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes - http.SetCookie(w, &http.Cookie{ - Name: "__Host-pinniped-csrf", - Value: string(csrfValue), - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - Secure: true, - }) - - stateParamData := upstreamStateParamData{ - AuthParams: authorizeRequester.GetRequestForm().Encode(), - Nonce: nonceValue, - CSRFToken: csrfValue, - PKCECode: pkceValue, - StateParamFormatVersion: stateParamFormatVersion, + encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, encoder) + if err != nil { + plog.Error("authorize upstream state param error", err) + return err } - encodedStateParamValue, err := encoder.Encode("s", stateParamData) - // TODO handle the above error + + addCSRFSetCookieHeader(w, csrfValue) http.Redirect(w, r, upstreamOAuthConfig.AuthCodeURL( @@ -138,15 +136,6 @@ func NewHandler( }) } -// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param. -type upstreamStateParamData struct { - AuthParams string `json:"p"` - Nonce nonce.Nonce `json:"n"` - CSRFToken csrftoken.CSRFToken `json:"c"` - PKCECode pkce.Code `json:"k"` - StateParamFormatVersion string `json:"v"` -} - func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { @@ -170,22 +159,59 @@ func generateValues( ) (csrftoken.CSRFToken, nonce.Nonce, pkce.Code, error) { csrfValue, err := generateCSRF() if err != nil { - plog.InfoErr("error generating csrf param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating CSRF token", err) } nonceValue, err := generateNonce() if err != nil { - plog.InfoErr("error generating nonce param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating nonce param", err) } pkceValue, err := generatePKCE() if err != nil { - plog.InfoErr("error generating PKCE param", err) return "", "", "", httperr.Wrap(http.StatusInternalServerError, "error generating PKCE param", err) } return csrfValue, nonceValue, pkceValue, nil } +// Keep the JSON to a minimal size because the upstream provider could impose size limitations on the state param. +type upstreamStateParamData struct { + AuthParams string `json:"p"` + Nonce nonce.Nonce `json:"n"` + CSRFToken csrftoken.CSRFToken `json:"c"` + PKCECode pkce.Code `json:"k"` + StateParamFormatVersion string `json:"v"` +} + +func upstreamStateParam( + authorizeRequester fosite.AuthorizeRequester, + nonceValue nonce.Nonce, + csrfValue csrftoken.CSRFToken, + pkceValue pkce.Code, + encoder Encoder, +) (string, error) { + stateParamData := upstreamStateParamData{ + AuthParams: authorizeRequester.GetRequestForm().Encode(), + Nonce: nonceValue, + CSRFToken: csrfValue, + PKCECode: pkceValue, + StateParamFormatVersion: upstreamStateParamFormatVersion, + } + encodedStateParamValue, err := encoder.Encode(upstreamStateParamEncodingName, stateParamData) + if err != nil { + return "", httperr.Wrap(http.StatusInternalServerError, "error encoding upstream state param", err) + } + return encodedStateParamValue, nil +} + +func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken) { + http.SetCookie(w, &http.Cookie{ + Name: csrfCookieName, + Value: string(csrfValue), + HttpOnly: true, + SameSite: http.SameSiteStrictMode, + Secure: true, + }) +} + func fositeErrorForLog(err error) []interface{} { rfc6749Error := fosite.ErrorToRFC6749Error(err) keysAndValues := make([]interface{}, 0) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index cc708f1a..46557937 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -28,7 +28,8 @@ import ( func TestAuthorizationEndpoint(t *testing.T) { const ( - downstreamRedirectURI = "http://127.0.0.1/callback" + downstreamRedirectURI = "http://127.0.0.1/callback" + downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" ) var ( @@ -144,8 +145,8 @@ func TestAuthorizationEndpoint(t *testing.T) { expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" var encoderHashKey = []byte("fake-hash-secret") - var encoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption - encoder.SetSerializer(securecookie.JSONEncoder{}) + var happyEncoder = securecookie.New(encoderHashKey, nil) // note that nil block key argument turns off encryption + happyEncoder.SetSerializer(securecookie.JSONEncoder{}) encodeQuery := func(query map[string]string) string { values := url.Values{} @@ -168,22 +169,24 @@ func TestAuthorizationEndpoint(t *testing.T) { return urlToReturn } - happyGetRequestQueryMap := map[string]string{ - "response_type": "code", - "scope": "openid profile email", - "client_id": "pinniped-cli", - "state": "some-state-value", - "nonce": "some-nonce-value", - "code_challenge": "some-challenge", - "code_challenge_method": "S256", - "redirect_uri": downstreamRedirectURI, + happyGetRequestQueryMap := func(downstreamRedirectURI string) map[string]string { + return map[string]string{ + "response_type": "code", + "scope": "openid profile email", + "client_id": "pinniped-cli", + "state": "some-state-value", + "nonce": "some-nonce-value", + "code_challenge": "some-challenge", + "code_challenge_method": "S256", + "redirect_uri": downstreamRedirectURI, + } } - happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap) + happyGetRequestPath := pathWithQuery("/some/path", happyGetRequestQueryMap(downstreamRedirectURI)) modifiedHappyGetRequestPath := func(queryOverrides map[string]string) string { copyOfHappyGetRequestQueryMap := map[string]string{} - for k, v := range happyGetRequestQueryMap { + for k, v := range happyGetRequestQueryMap(downstreamRedirectURI) { copyOfHappyGetRequestQueryMap[k] = v } for k, v := range queryOverrides { @@ -197,38 +200,33 @@ func TestAuthorizationEndpoint(t *testing.T) { return pathWithQuery("/some/path", copyOfHappyGetRequestQueryMap) } - // We're going to use this value to make assertions, so specify the exact expected value. - happyUpstreamStateParam, err := encoder.Encode("s", - // Ensure that the order of the serialized fields is exactly this order, so we can make simpler equality assertions below. - struct { - P string `json:"p"` - N string `json:"n"` - C string `json:"c"` - K string `json:"k"` - V string `json:"v"` - }{ - P: encodeQuery(happyGetRequestQueryMap), - N: happyNonce, - C: happyCSRF, - K: happyPKCE, - V: "1", - }, - ) - require.NoError(t, err) + happyExpectedUpstreamStateParam := func(downstreamRedirectURI string) string { + encoded, err := happyEncoder.Encode("s", + expectedUpstreamStateParamFormat{ + P: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)), + N: happyNonce, + C: happyCSRF, + K: happyPKCE, + V: "1", + }, + ) + require.NoError(t, err) + return encoded + } - happyGetRequestExpectedRedirectLocation := urlWithQuery(upstreamAuthURL.String(), - map[string]string{ + happyExpectedRedirectLocation := func(downstreamRedirectURI string) string { + return urlWithQuery(upstreamAuthURL.String(), map[string]string{ "response_type": "code", "access_type": "offline", "scope": "scope1 scope2", "client_id": "some-client-id", - "state": happyUpstreamStateParam, + "state": happyExpectedUpstreamStateParam(downstreamRedirectURI), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", "redirect_uri": issuer + "/callback/some-idp", - }, - ) + }) + } happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF) @@ -240,6 +238,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) + encoder securecookie.Codec method string path string contentType string @@ -251,8 +250,9 @@ func TestAuthorizationEndpoint(t *testing.T) { wantBodyJSON string wantLocationHeader string wantCSRFCookieHeader string - } + wantUpstreamStateParamInLocationHeader bool + } tests := []testCase{ { name: "happy path using GET", @@ -261,54 +261,59 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantBodyString: fmt.Sprintf(`Found.%s`, - html.EscapeString(happyGetRequestExpectedRedirectLocation), + html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURI)), "\n\n", ), - wantLocationHeader: happyGetRequestExpectedRedirectLocation, - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI), + wantUpstreamStateParamInLocationHeader: true, }, { - name: "happy path using POST", + name: "happy path using POST", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + encoder: happyEncoder, + method: http.MethodPost, + path: "/some/path", + contentType: "application/x-www-form-urlencoded", + body: encodeQuery(happyGetRequestQueryMap(downstreamRedirectURI)), + wantStatus: http.StatusFound, + wantContentType: "", + wantBodyString: "", + wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURI), + wantUpstreamStateParamInLocationHeader: true, + }, + { + name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: issuer, idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - method: http.MethodPost, - path: "/some/path", - contentType: "application/x-www-form-urlencoded", - body: url.Values{ - "response_type": []string{"code"}, - "scope": []string{"openid profile email"}, - "client_id": []string{"pinniped-cli"}, - "state": []string{"some-state-value"}, - "code_challenge": []string{"some-challenge"}, - "code_challenge_method": []string{"S256"}, - "redirect_uri": []string{downstreamRedirectURI}, - }.Encode(), - wantStatus: http.StatusFound, - wantContentType: "", - wantBodyString: "", - wantLocationHeader: happyGetRequestExpectedRedirectLocation, - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, - }, - { - name: "downstream client does not exist", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), - generateCSRF: happyCSRFGenerator, - generatePKCE: happyPKCEGenerator, - generateNonce: happyNonceGenerator, - method: http.MethodGet, - path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}), - wantStatus: http.StatusUnauthorized, - wantContentType: "application/json; charset=utf-8", - wantBodyJSON: fositeInvalidClientErrorBody, + encoder: happyEncoder, + method: http.MethodGet, + path: modifiedHappyGetRequestPath(map[string]string{ + "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client + }), + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + wantBodyString: fmt.Sprintf(`Found.%s`, + html.EscapeString(happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort)), + "\n\n", + ), + wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + wantLocationHeader: happyExpectedRedirectLocation(downstreamRedirectURIWithDifferentPort), + wantUpstreamStateParamInLocationHeader: true, }, { name: "downstream redirect uri does not match what is configured for client", @@ -317,6 +322,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{ "redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client", @@ -326,24 +332,18 @@ func TestAuthorizationEndpoint(t *testing.T) { wantBodyJSON: fositeInvalidRedirectURIErrorBody, }, { - name: "happy path when downstream redirect uri matches what is configured for client except for the port number", - issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), - generateCSRF: happyCSRFGenerator, - generatePKCE: happyPKCEGenerator, - generateNonce: happyNonceGenerator, - method: http.MethodGet, - path: modifiedHappyGetRequestPath(map[string]string{ - "redirect_uri": "http://127.0.0.1:42/callback", - }), - wantStatus: http.StatusFound, - wantContentType: "text/html; charset=utf-8", - wantBodyString: fmt.Sprintf(`Found.%s`, - html.EscapeString(happyGetRequestExpectedRedirectLocation), - "\n\n", - ), - wantLocationHeader: happyGetRequestExpectedRedirectLocation, - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + name: "downstream client does not exist", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + encoder: happyEncoder, + method: http.MethodGet, + path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}), + wantStatus: http.StatusUnauthorized, + wantContentType: "application/json; charset=utf-8", + wantBodyJSON: fositeInvalidClientErrorBody, }, { name: "response type is unsupported", @@ -352,6 +352,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}), wantStatus: http.StatusFound, @@ -366,6 +367,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}), wantStatus: http.StatusFound, @@ -380,6 +382,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}), wantStatus: http.StatusFound, @@ -394,6 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}), wantStatus: http.StatusUnauthorized, @@ -407,6 +411,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}), wantStatus: http.StatusFound, @@ -421,6 +426,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}), wantStatus: http.StatusFound, @@ -435,6 +441,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}), wantStatus: http.StatusFound, @@ -449,6 +456,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}), wantStatus: http.StatusFound, @@ -465,6 +473,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}), wantStatus: http.StatusFound, @@ -479,6 +488,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}), wantStatus: http.StatusFound, @@ -486,6 +496,20 @@ func TestAuthorizationEndpoint(t *testing.T) { wantLocationHeader: urlWithQuery(downstreamRedirectURI, fositeInvalidStateErrorQuery), wantBodyString: "", }, + { + name: "error while encoding upstream state param", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + encoder: &errorReturningEncoder{}, + method: http.MethodGet, + path: happyGetRequestPath, + wantStatus: http.StatusInternalServerError, + wantContentType: "text/plain; charset=utf-8", + wantBodyString: "Internal Server Error: error encoding upstream state param\n", + }, { name: "error while generating CSRF token", issuer: issuer, @@ -493,6 +517,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, @@ -506,6 +531,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, + encoder: happyEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, @@ -519,6 +545,7 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, + encoder: happyEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, @@ -586,19 +613,22 @@ func TestAuthorizationEndpoint(t *testing.T) { require.Equal(t, test.wantStatus, rsp.Code) requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) + if test.wantLocationHeader != "" { + actualLocation := rsp.Header().Get("Location") + if test.wantUpstreamStateParamInLocationHeader { + requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder) + } + requireEqualURLs(t, actualLocation, test.wantLocationHeader) + } else { + require.Empty(t, rsp.Header().Values("Location")) + } + if test.wantBodyJSON != "" { require.JSONEq(t, test.wantBodyJSON, rsp.Body.String()) } else { require.Equal(t, test.wantBodyString, rsp.Body.String()) } - if test.wantLocationHeader != "" { - actualLocation := rsp.Header().Get("Location") - requireEqualURLs(t, actualLocation, test.wantLocationHeader) - } else { - require.Empty(t, rsp.Header().Values("Location")) - } - if test.wantCSRFCookieHeader != "" { require.Len(t, rsp.Header().Values("Set-Cookie"), 1) actualCookie := rsp.Header().Get("Set-Cookie") @@ -611,7 +641,7 @@ func TestAuthorizationEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder) + subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) runOneTestCase(t, test, subject) }) } @@ -620,7 +650,7 @@ func TestAuthorizationEndpoint(t *testing.T) { test := tests[0] require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case - subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, encoder) + subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) runOneTestCase(t, test, subject) @@ -640,7 +670,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "access_type": "offline", "scope": "other-scope1 other-scope2", "client_id": "some-other-client-id", - "state": happyUpstreamStateParam, + "state": happyExpectedUpstreamStateParam(downstreamRedirectURI), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", @@ -660,6 +690,26 @@ func TestAuthorizationEndpoint(t *testing.T) { }) } +// Declare a separate type from the production code to ensure that the state param's contents was serialized +// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of +// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality +// assertions about the redirect URL in this test. +type expectedUpstreamStateParamFormat struct { + P string `json:"p"` + N string `json:"n"` + C string `json:"c"` + K string `json:"k"` + V string `json:"v"` +} + +type errorReturningEncoder struct { + securecookie.Codec +} + +func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { + return "", fmt.Errorf("some encoding error") +} + func requireEqualContentType(t *testing.T, actual string, expected string) { t.Helper() @@ -676,6 +726,28 @@ func requireEqualContentType(t *testing.T, actual string, expected string) { require.Equal(t, actualContentTypeParams, expectedContentTypeParams) } +func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder securecookie.Codec) { + t.Helper() + actualLocationURL, err := url.Parse(actualURL) + require.NoError(t, err) + expectedLocationURL, err := url.Parse(expectedURL) + require.NoError(t, err) + + expectedQueryStateParam := expectedLocationURL.Query().Get("state") + require.NotEmpty(t, expectedQueryStateParam) + var expectedDecodedStateParam expectedUpstreamStateParamFormat + err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) + require.NoError(t, err) + + actualQueryStateParam := actualLocationURL.Query().Get("state") + require.NotEmpty(t, actualQueryStateParam) + var actualDecodedStateParam expectedUpstreamStateParamFormat + err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) + require.NoError(t, err) + + require.Equal(t, expectedDecodedStateParam, actualDecodedStateParam) +} + func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) { t.Helper() actualLocationURL, err := url.Parse(actualURL) diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 77112fca..1cb37fc8 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -9,13 +9,12 @@ import ( "sync" "github.com/gorilla/securecookie" - "go.pinniped.dev/internal/oidc/csrftoken" - "github.com/ory/fosite" "github.com/ory/fosite/storage" "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" + "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/provider" diff --git a/internal/plog/doc.go b/internal/plog/doc.go deleted file mode 100644 index 41cc3071..00000000 --- a/internal/plog/doc.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2020 the Pinniped contributors. All Rights Reserved. -// SPDX-License-Identifier: Apache-2.0 - -// Package plog implements a thin layer over klog to help enforce pinniped's logging convention. -// Logs are always structured as a constant message with key and value pairs of related metadata. -// The logging levels in order of increasing verbosity are: -// error, warning, info, debug, trace and all. -// error and warning logs are always emitted (there is no way for the end user to disable them), -// and thus should be used sparingly. Ideally, logs at these levels should be actionable. -// info should be reserved for "nice to know" information. It should be possible to run a production -// pinniped server at the info log level with no performance degradation due to high log volume. -// debug should be used for information targeted at developers and to aid in support cases. Care must -// be taken at this level to not leak any secrets into the log stream. That is, even though debug may -// cause performance issues in production, it must not cause security issues in production. -// trace should be used to log information related to timing (i.e. the time it took a controller to sync). -// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information -// about the current state of the process, but that, along with performance degradation, is expected. -// all is reserved for the most verbose and security sensitive information. At this level, full request -// metadata such as headers and parameters along with the body may be logged. This level is completely -// unfit for production use both from a performance and security standpoint. Using it is generally an -// act of desperation to determine why the system is broken. -package plog diff --git a/internal/plog/plog.go b/internal/plog/plog.go index d3b6efa8..ffcefdb8 100644 --- a/internal/plog/plog.go +++ b/internal/plog/plog.go @@ -1,12 +1,37 @@ // Copyright 2020 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +// Package plog implements a thin layer over klog to help enforce pinniped's logging convention. +// Logs are always structured as a constant message with key and value pairs of related metadata. +// +// The logging levels in order of increasing verbosity are: +// error, warning, info, debug, trace and all. +// +// error and warning logs are always emitted (there is no way for the end user to disable them), +// and thus should be used sparingly. Ideally, logs at these levels should be actionable. +// +// info should be reserved for "nice to know" information. It should be possible to run a production +// pinniped server at the info log level with no performance degradation due to high log volume. +// debug should be used for information targeted at developers and to aid in support cases. Care must +// be taken at this level to not leak any secrets into the log stream. That is, even though debug may +// cause performance issues in production, it must not cause security issues in production. +// +// trace should be used to log information related to timing (i.e. the time it took a controller to sync). +// Just like debug, trace should not leak secrets into the log stream. trace will likely leak information +// about the current state of the process, but that, along with performance degradation, is expected. +// +// all is reserved for the most verbose and security sensitive information. At this level, full request +// metadata such as headers and parameters along with the body may be logged. This level is completely +// unfit for production use both from a performance and security standpoint. Using it is generally an +// act of desperation to determine why the system is broken. package plog import "k8s.io/klog/v2" +const errorKey = "error" + // Use Error to log an unexpected system error. -func Error(err error, msg string, keysAndValues ...interface{}) { +func Error(msg string, err error, keysAndValues ...interface{}) { klog.ErrorS(err, msg, keysAndValues...) } @@ -19,23 +44,38 @@ func Warning(msg string, keysAndValues ...interface{}) { klog.V(klogLevelWarning).InfoS(msg, keysAndValues...) } +// Use WarningErr to issue a Warning message with an error object as part of the message. +func WarningErr(msg string, err error, keysAndValues ...interface{}) { + Warning(msg, append([]interface{}{errorKey, err}, keysAndValues)...) +} + func Info(msg string, keysAndValues ...interface{}) { klog.V(klogLevelInfo).InfoS(msg, keysAndValues...) } // Use InfoErr to log an expected error, e.g. validation failure of an http parameter. func InfoErr(msg string, err error, keysAndValues ...interface{}) { - klog.V(klogLevelInfo).InfoS(msg, append([]interface{}{"error", err}, keysAndValues)...) + Info(msg, append([]interface{}{errorKey, err}, keysAndValues)...) } func Debug(msg string, keysAndValues ...interface{}) { klog.V(klogLevelDebug).InfoS(msg, keysAndValues...) } +// Use DebugErr to issue a Debug message with an error object as part of the message. +func DebugErr(msg string, err error, keysAndValues ...interface{}) { + Debug(msg, append([]interface{}{errorKey, err}, keysAndValues)...) +} + func Trace(msg string, keysAndValues ...interface{}) { klog.V(klogLevelTrace).InfoS(msg, keysAndValues...) } +// Use TraceErr to issue a Trace message with an error object as part of the message. +func TraceErr(msg string, err error, keysAndValues ...interface{}) { + Trace(msg, append([]interface{}{errorKey, err}, keysAndValues)...) +} + func All(msg string, keysAndValues ...interface{}) { klog.V(klogLevelAll).InfoS(msg, keysAndValues...) }