diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index b47d4984..de3c7f71 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/gorilla/securecookie" + "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" @@ -35,6 +37,9 @@ const ( // The name of the browser cookie which shall hold our CSRF value. // `__Host` prefix has a special meaning. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#Cookie_prefixes csrfCookieName = "__Host-pinniped-csrf" + + // The `name` passed to the encoder for encoding and decoding the CSRF cookie contents. + csrfCookieEncodingName = "csrf" ) type IDPListGetter interface { @@ -53,7 +58,8 @@ func NewHandler( generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), generateNonce func() (nonce.Nonce, error), - encoder Encoder, + upstreamStateEncoder Encoder, + cookieCodec securecookie.Codec, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodPost && r.Method != http.MethodGet { @@ -63,6 +69,12 @@ func NewHandler( return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method) } + csrfFromCookie, err := readCSRFCookie(r, cookieCodec) + if err != nil { + plog.InfoErr("error reading CSRF cookie", err) + return err + } + authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) if err != nil { plog.Info("authorize request error", fositeErrorForLog(err)...) @@ -77,11 +89,7 @@ func NewHandler( } // Grant the openid scope (for now) if they asked for it so that `NewAuthorizeResponse` will perform its OIDC validations. - for _, scope := range authorizeRequester.GetRequestedScopes() { - if scope == "openid" { - authorizeRequester.GrantScope(scope) - } - } + grantOpenIDScopeIfRequested(authorizeRequester) now := time.Now() _, err = oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ @@ -103,6 +111,9 @@ func NewHandler( plog.Error("authorize generate error", err) return err } + if csrfFromCookie != "" { + csrfValue = csrfFromCookie + } upstreamOAuthConfig := oauth2.Config{ ClientID: upstreamIDP.ClientID, @@ -113,13 +124,20 @@ func NewHandler( Scopes: upstreamIDP.Scopes, } - encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, encoder) + encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) if err != nil { plog.Error("authorize upstream state param error", err) return err } - addCSRFSetCookieHeader(w, csrfValue) + if csrfFromCookie == "" { + // We did not receive an incoming CSRF cookie, so write a new one. + err := addCSRFSetCookieHeader(w, csrfValue, cookieCodec) + if err != nil { + plog.Error("error setting CSRF cookie", err) + return err + } + } http.Redirect(w, r, upstreamOAuthConfig.AuthCodeURL( @@ -136,6 +154,30 @@ func NewHandler( }) } +func readCSRFCookie(r *http.Request, codec securecookie.Codec) (csrftoken.CSRFToken, error) { + receivedCSRFCookie, err := r.Cookie(csrfCookieName) + if err != nil { + // Error means that the cookie was not found + return "", nil + } + + var csrfFromCookie csrftoken.CSRFToken + err = codec.Decode(csrfCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie) + if err != nil { + return "", httperr.Wrap(http.StatusUnprocessableEntity, "error reading CSRF cookie", err) + } + + return csrfFromCookie, nil +} + +func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { + for _, scope := range authorizeRequester.GetRequestedScopes() { + if scope == "openid" { + authorizeRequester.GrantScope(scope) + } + } +} + func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { @@ -202,14 +244,21 @@ func upstreamStateParam( return encodedStateParamValue, nil } -func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken) { +func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken, codec securecookie.Codec) error { + encodedCSRFValue, err := codec.Encode(csrfCookieEncodingName, csrfValue) + if err != nil { + return httperr.Wrap(http.StatusInternalServerError, "error encoding CSRF cookie", err) + } + http.SetCookie(w, &http.Cookie{ Name: csrfCookieName, - Value: string(csrfValue), + Value: encodedCSRFValue, HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: true, }) + + return nil } func fositeErrorForLog(err error) []interface{} { diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 2a38a221..e0fdbacd 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "testing" @@ -137,10 +138,17 @@ func TestAuthorizationEndpoint(t *testing.T) { // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 expectedUpstreamCodeChallenge := "VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g" - var encoderHashKey = []byte("fake-hash-secret") - var encoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES - var happyEncoder = securecookie.New(encoderHashKey, encoderBlockKey) // note that nil block key argument turns off encryption - happyEncoder.SetSerializer(securecookie.JSONEncoder{}) + var stateEncoderHashKey = []byte("fake-hash-secret") + var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + var cookieEncoderHashKey = []byte("fake-hash-secret2") + var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateEncoder.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) encodeQuery := func(query map[string]string) string { values := url.Values{} @@ -196,12 +204,16 @@ func TestAuthorizationEndpoint(t *testing.T) { return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) } - expectedUpstreamStateParam := func(queryOverrides map[string]string) string { - encoded, err := happyEncoder.Encode("s", + expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string { + csrf := happyCSRF + if csrfValueOverride != "" { + csrf = csrfValueOverride + } + encoded, err := happyStateEncoder.Encode("s", expectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), N: happyNonce, - C: happyCSRF, + C: csrf, K: happyPKCE, V: "1", }, @@ -224,7 +236,9 @@ func TestAuthorizationEndpoint(t *testing.T) { }) } - happyCSRFSetCookieHeaderValue := fmt.Sprintf("__Host-pinniped-csrf=%s; HttpOnly; Secure; SameSite=Strict", happyCSRF) + incomingCookieCSRFValue := "csrf-value-from-cookie" + encodedIncomingCookieCSRFValue, err := happyCookieEncoder.Encode("csrf", incomingCookieCSRFValue) + require.NoError(t, err) type testCase struct { name string @@ -234,37 +248,58 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) - encoder securecookie.Codec + stateEncoder securecookie.Codec + cookieEncoder securecookie.Codec method string path string contentType string body string + csrfCookie string - wantStatus int - wantContentType string - wantBodyString string - wantBodyJSON string - wantLocationHeader string - wantCSRFCookieHeader string + wantStatus int + wantContentType string + wantBodyString string + wantBodyJSON string + wantLocationHeader string + wantCSRFValueInCookieHeader string wantUpstreamStateParamInLocationHeader bool wantBodyStringWithLocationInHref bool } tests := []testCase{ { - name: "happy path using GET", + name: "happy path using GET without a CSRF cookie", issuer: issuer, idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil)), + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantUpstreamStateParamInLocationHeader: true, + wantBodyStringWithLocationInHref: true, + }, + { + name: "happy path using GET with a CSRF cookie", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, + method: http.MethodGet, + path: happyGetRequestPath, + csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -275,7 +310,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodPost, path: "/some/path", contentType: "application/x-www-form-urlencoded", @@ -283,8 +319,8 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: "", wantBodyString: "", - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil)), + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), wantUpstreamStateParamInLocationHeader: true, }, { @@ -294,17 +330,18 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client }), - wantStatus: http.StatusFound, - wantContentType: "text/html; charset=utf-8", - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client - })), + }, "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -315,7 +352,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{ "redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client", @@ -331,7 +369,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}), wantStatus: http.StatusUnauthorized, @@ -345,7 +384,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}), wantStatus: http.StatusFound, @@ -360,7 +400,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid profile email tuna"}), wantStatus: http.StatusFound, @@ -375,7 +416,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}), wantStatus: http.StatusFound, @@ -390,7 +432,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}), wantStatus: http.StatusUnauthorized, @@ -404,7 +447,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}), wantStatus: http.StatusFound, @@ -419,7 +463,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}), wantStatus: http.StatusFound, @@ -434,7 +479,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}), wantStatus: http.StatusFound, @@ -449,7 +495,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}), wantStatus: http.StatusFound, @@ -466,7 +513,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}), wantStatus: http.StatusFound, @@ -481,14 +529,17 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, // The following prompt value is illegal when openid is requested, but note that openid is not requested. - path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}), - wantStatus: http.StatusFound, - wantContentType: "text/html; charset=utf-8", - wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{"prompt": "none login", "scope": "email"})), + path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}), + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( + map[string]string{"prompt": "none login", "scope": "email"}, "", + )), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -499,7 +550,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}), wantStatus: http.StatusFound, @@ -514,13 +566,29 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: &errorReturningEncoder{}, + stateEncoder: &errorReturningEncoder{}, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, wantContentType: "text/plain; charset=utf-8", wantBodyString: "Internal Server Error: error encoding upstream state param\n", }, + { + name: "error while encoding CSRF cookie value for new cookie", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: &errorReturningEncoder{}, + method: http.MethodGet, + path: happyGetRequestPath, + wantStatus: http.StatusInternalServerError, + wantContentType: "text/plain; charset=utf-8", + wantBodyString: "Internal Server Error: error encoding CSRF cookie\n", + }, { name: "error while generating CSRF token", issuer: issuer, @@ -528,7 +596,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, @@ -542,7 +611,8 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, @@ -556,13 +626,30 @@ func TestAuthorizationEndpoint(t *testing.T) { generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, - encoder: happyEncoder, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusInternalServerError, wantContentType: "text/plain; charset=utf-8", wantBodyString: "Internal Server Error: error generating PKCE param\n", }, + { + name: "error while decoding CSRF cookie", + issuer: issuer, + idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, + method: http.MethodGet, + path: happyGetRequestPath, + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusUnprocessableEntity, + wantContentType: "text/plain; charset=utf-8", + wantBodyString: "Unprocessable Entity: error reading CSRF cookie\n", + }, { name: "no upstream providers are configured", issuer: issuer, @@ -618,6 +705,9 @@ func TestAuthorizationEndpoint(t *testing.T) { runOneTestCase := func(t *testing.T, test testCase, subject http.Handler) { req := httptest.NewRequest(test.method, test.path, strings.NewReader(test.body)) req.Header.Set("Content-Type", test.contentType) + if test.csrfCookie != "" { + req.Header.Set("Cookie", test.csrfCookie) + } rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) @@ -627,7 +717,7 @@ func TestAuthorizationEndpoint(t *testing.T) { actualLocation := rsp.Header().Get("Location") if test.wantLocationHeader != "" { if test.wantUpstreamStateParamInLocationHeader { - requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder) + requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.stateEncoder) } // The upstream state param is encoded using a timestamp at the beginning so we don't want to // compare those states since they may be different, but we do want to compare the downstream @@ -647,10 +737,17 @@ func TestAuthorizationEndpoint(t *testing.T) { require.Equal(t, test.wantBodyString, rsp.Body.String()) } - if test.wantCSRFCookieHeader != "" { + if test.wantCSRFValueInCookieHeader != "" { require.Len(t, rsp.Header().Values("Set-Cookie"), 1) actualCookie := rsp.Header().Get("Set-Cookie") - require.Equal(t, actualCookie, test.wantCSRFCookieHeader) + regex := regexp.MustCompile("__Host-pinniped-csrf=([^;]+); HttpOnly; Secure; SameSite=Strict") + submatches := regex.FindStringSubmatch(actualCookie) + require.Len(t, submatches, 2) + captured := submatches[1] + var decodedCSRFCookieValue string + err := test.cookieEncoder.Decode("csrf", captured, &decodedCSRFCookieValue) + require.NoError(t, err) + require.Equal(t, test.wantCSRFValueInCookieHeader, decodedCSRFCookieValue) } else { require.Empty(t, rsp.Header().Values("Set-Cookie")) } @@ -659,16 +756,16 @@ func TestAuthorizationEndpoint(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) + subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder) runOneTestCase(t, test, subject) }) } t.Run("allows upstream provider configuration to change between requests", func(t *testing.T) { test := tests[0] - require.Equal(t, "happy path using GET", test.name) // re-use the happy path test case + require.Equal(t, "happy path using GET without a CSRF cookie", test.name) // re-use the happy path test case - subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.encoder) + subject := NewHandler(test.issuer, test.idpListGetter, oauthHelper, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder) runOneTestCase(t, test, subject) @@ -688,7 +785,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "access_type": "offline", "scope": "other-scope1 other-scope2", "client_id": "some-other-client-id", - "state": expectedUpstreamStateParam(nil), + "state": expectedUpstreamStateParam(nil, ""), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 6badff09..0687ba22 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -72,13 +72,17 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { // the upstream callback endpoint is called later. oauthHelper := oidc.FositeOauth2Helper(oidc.NullStorage{}, []byte("some secret - must have at least 32 bytes")) // TODO replace this secret + // TODO use different codecs for the state and the cookie, because: + // 1. we would like to state to have an embedded expiration date while the cookie does not need that + // 2. we would like each downstream provider to use different secrets for signing/encrypting the upstream state, not share secrets + // 3. we would like *all* downstream providers to use the *same* signing key for the CSRF cookie (which doesn't need to be encrypted) because cookies are sent per-domain and our issuers can share a domain name (but have different paths) var encoderHashKey = []byte("fake-hash-secret") // TODO replace this secret var encoderBlockKey = []byte("16-bytes-aaaaaaa") // TODO replace this secret var encoder = securecookie.New(encoderHashKey, encoderBlockKey) encoder.SetSerializer(securecookie.JSONEncoder{}) authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder) + m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) }