diff --git a/internal/federationdomain/endpointsmanager/manager_test.go b/internal/federationdomain/endpointsmanager/manager_test.go index de6ea70b..819aa361 100644 --- a/internal/federationdomain/endpointsmanager/manager_test.go +++ b/internal/federationdomain/endpointsmanager/manager_test.go @@ -50,9 +50,14 @@ func TestManager(t *testing.T) { issuer2 = "https://example.com/some/path/more/deeply/nested/path" // note that this is a sub-path of the other issuer url issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2KeyID = "issuer2-key" - upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" - upstreamIDPName = "test-idp" - upstreamResourceUID = "test-resource-uid" + upstreamIDPAuthorizationURL1 = "https://test-upstream.com/auth1" + upstreamIDPAuthorizationURL2 = "https://test-upstream.com/auth2" + upstreamIDPDisplayName1 = "test-idp-display-name-1" + upstreamIDPDisplayName2 = "test-idp-display-name-2" + upstreamIDPName1 = "test-idp-1" + upstreamIDPName2 = "test-idp-2" + upstreamResourceUID1 = "test-resource-uid-1" + upstreamResourceUID2 = "test-resource-uid-2" upstreamIDPType = "oidc" downstreamClientID = "pinniped-cli" downstreamRedirectURL = "http://127.0.0.1:12345/callback" @@ -82,7 +87,7 @@ func TestManager(t *testing.T) { r.False(fallbackHandlerWasCalled) // Minimal check to ensure that the right discovery endpoint was called - r.Equal(http.StatusOK, recorder.Code) + r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder) responseBody, err := io.ReadAll(recorder.Body) r.NoError(err) parsedDiscoveryResult := discovery.Metadata{} @@ -92,7 +97,7 @@ func TestManager(t *testing.T) { r.Equal(parsedDiscoveryResult.SupervisorDiscovery.PinnipedIDPsEndpoint, expectedIssuer+oidc.PinnipedIDPsPathV1Alpha1) } - requirePinnipedIDPsDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedIDPName, expectedIDPType string, expectedFlows []string) { + requirePinnipedIDPsDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix string, expectedIDPNames []string, expectedIDPTypes string, expectedFlows []string) { recorder := httptest.NewRecorder() subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.PinnipedIDPsPathV1Alpha1+requestURLSuffix)) @@ -102,12 +107,18 @@ func TestManager(t *testing.T) { expectedFlowsJSON, err := json.Marshal(expectedFlows) require.NoError(t, err) + expectedIDPJSONList := []string{} + for i := range expectedIDPNames { + expectedIDPJSONList = append(expectedIDPJSONList, fmt.Sprintf(`{"name":"%s","type":"%s","flows":%s}`, + expectedIDPNames[i], expectedIDPTypes, expectedFlowsJSON)) + } + // Minimal check to ensure that the right IDP discovery endpoint was called - r.Equal(http.StatusOK, recorder.Code) + r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder) responseBody, err := io.ReadAll(recorder.Body) r.NoError(err) r.Equal( - fmt.Sprintf(`{"pinniped_identity_providers":[{"name":"%s","type":"%s","flows":%s}]}`+"\n", expectedIDPName, expectedIDPType, expectedFlowsJSON), + fmt.Sprintf(`{"pinniped_identity_providers":[%s]}`+"\n", strings.Join(expectedIDPJSONList, ",")), string(responseBody), ) } @@ -120,7 +131,7 @@ func TestManager(t *testing.T) { r.False(fallbackHandlerWasCalled) // Minimal check to ensure that the right endpoint was called - r.Equal(http.StatusSeeOther, recorder.Code) + r.Equal(http.StatusSeeOther, recorder.Code, "unexpected response:", recorder) actualLocation := recorder.Header().Get("Location") r.True( strings.HasPrefix(actualLocation, expectedRedirectLocationPrefix), @@ -159,7 +170,7 @@ func TestManager(t *testing.T) { // Check just enough of the response to ensure that we wired up the callback endpoint correctly. // The endpoint's own unit tests cover everything else. - r.Equal(http.StatusSeeOther, recorder.Code) + r.Equal(http.StatusSeeOther, recorder.Code, "unexpected response:", recorder) actualLocation := recorder.Header().Get("Location") r.True( strings.HasPrefix(actualLocation, downstreamRedirectURL), @@ -171,7 +182,7 @@ func TestManager(t *testing.T) { actualLocationQueryParams := parsedLocation.Query() r.Contains(actualLocationQueryParams, "code") r.Equal("openid username groups", actualLocationQueryParams.Get("scope")) - r.Equal("some-state-value-with-enough-bytes-to-exceed-min-allowed", actualLocationQueryParams.Get("state")) + r.Equal("some-state", actualLocationQueryParams.Get("state")) // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3, @@ -198,8 +209,8 @@ func TestManager(t *testing.T) { r.False(fallbackHandlerWasCalled) // Minimal check to ensure that the right endpoint was called + r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder) var body map[string]interface{} - r.Equal(http.StatusOK, recorder.Code) r.NoError(json.Unmarshal(recorder.Body.Bytes(), &body)) r.Contains(body, "id_token") r.Contains(body, "access_token") @@ -228,7 +239,7 @@ func TestManager(t *testing.T) { r.False(fallbackHandlerWasCalled) // Minimal check to ensure that the right JWKS endpoint was called - r.Equal(http.StatusOK, recorder.Code) + r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder) responseBody, err := io.ReadAll(recorder.Body) r.NoError(err) parsedJWKSResult := jose.JSONWebKeySet{} @@ -246,31 +257,53 @@ func TestManager(t *testing.T) { } dynamicJWKSProvider = jwks.NewDynamicJWKSProvider() - parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) + parsedUpstreamIDPAuthorizationURL1, err := url.Parse(upstreamIDPAuthorizationURL1) + r.NoError(err) + parsedUpstreamIDPAuthorizationURL2, err := url.Parse(upstreamIDPAuthorizationURL2) r.NoError(err) federationDomainIDPs = []*federationdomainproviders.FederationDomainIdentityProvider{ { - DisplayName: upstreamIDPName, - UID: upstreamResourceUID, + DisplayName: upstreamIDPDisplayName1, + UID: upstreamResourceUID1, + Transforms: idtransform.NewTransformationPipeline(), + }, + { + DisplayName: upstreamIDPDisplayName2, + UID: upstreamResourceUID2, Transforms: idtransform.NewTransformationPipeline(), }, } - idpLister := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). - WithName(upstreamIDPName). - WithClientID("test-client-id"). - WithResourceUID(upstreamResourceUID). - WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL). - WithScopes([]string{"test-scope"}). - WithIDTokenClaim("iss", "https://some-issuer.com"). - WithIDTokenClaim("sub", "some-subject"). - WithIDTokenClaim("username", "test-username"). - WithIDTokenClaim("groups", "test-group1"). - WithRefreshToken("some-opaque-token"). - WithoutAccessToken(). - Build(), - ).BuildDynamicUpstreamIDPProvider() + idpLister := oidctestutil.NewUpstreamIDPListerBuilder(). + WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(upstreamIDPName1). + WithClientID("test-client-id-1"). + WithResourceUID(upstreamResourceUID1). + WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL1). + WithScopes([]string{"test-scope"}). + WithIDTokenClaim("iss", "https://some-issuer.com"). + WithIDTokenClaim("sub", "some-subject"). + WithIDTokenClaim("username", "test-username"). + WithIDTokenClaim("groups", "test-group1"). + WithRefreshToken("some-opaque-token"). + WithoutAccessToken(). + Build(), + ). + WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(upstreamIDPName2). + WithClientID("test-client-id-2"). + WithResourceUID(upstreamResourceUID2). + WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL2). + WithScopes([]string{"test-scope"}). + WithIDTokenClaim("iss", "https://some-issuer.com"). + WithIDTokenClaim("sub", "some-subject"). + WithIDTokenClaim("username", "test-username"). + WithIDTokenClaim("groups", "test-group1"). + WithRefreshToken("some-opaque-token"). + WithoutAccessToken(). + Build(), + ).BuildDynamicUpstreamIDPProvider() kubeClient = fake.NewSimpleClientset() secretsClient := kubeClient.CoreV1().Secrets("some-namespace") @@ -326,14 +359,14 @@ func TestManager(t *testing.T) { requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2) requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2) - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "?some=query", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "?some=query", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) // Hostnames are case-insensitive, so test that we can handle that. - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1DifferentCaseHostname, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) - requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1DifferentCaseHostname, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) + requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows) issuer1JWKS := requireJWKSRequestToBeHandled(issuer1, "", issuer1KeyID) issuer2JWKS := requireJWKSRequestToBeHandled(issuer2, "", issuer2KeyID) @@ -344,35 +377,50 @@ func TestManager(t *testing.T) { requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) - authRequestParams := "?" + url.Values{ - "pinniped_idp_name": []string{upstreamIDPName}, + authRequestParamsIDP1 := "?" + url.Values{ + "pinniped_idp_name": []string{upstreamIDPDisplayName1}, "response_type": []string{"code"}, "scope": []string{"openid profile email username groups"}, "client_id": []string{downstreamClientID}, - "state": []string{"some-state-value-with-enough-bytes-to-exceed-min-allowed"}, + "state": []string{"some-state"}, "nonce": []string{"some-nonce-value-with-enough-bytes-to-exceed-min-allowed"}, "code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)}, "code_challenge_method": []string{"S256"}, "redirect_uri": []string{downstreamRedirectURL}, }.Encode() - requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) - requireAuthorizationRequestToBeHandled(issuer2, authRequestParams, upstreamIDPAuthorizationURL) + authRequestParamsIDP2 := "?" + url.Values{ + "pinniped_idp_name": []string{upstreamIDPDisplayName2}, + "response_type": []string{"code"}, + "scope": []string{"openid profile email username groups"}, + "client_id": []string{downstreamClientID}, + "state": []string{"some-state"}, + "nonce": []string{"some-nonce-value-with-enough-bytes-to-exceed-min-allowed"}, + "code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{downstreamRedirectURL}, + }.Encode() + + requireAuthorizationRequestToBeHandled(issuer1, authRequestParamsIDP1, upstreamIDPAuthorizationURL1) + requireAuthorizationRequestToBeHandled(issuer2, authRequestParamsIDP1, upstreamIDPAuthorizationURL1) + requireAuthorizationRequestToBeHandled(issuer1, authRequestParamsIDP2, upstreamIDPAuthorizationURL2) + requireAuthorizationRequestToBeHandled(issuer2, authRequestParamsIDP2, upstreamIDPAuthorizationURL2) // Hostnames are case-insensitive, so test that we can handle that. csrfCookieValue1, upstreamStateParam1 := - requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParamsIDP1, upstreamIDPAuthorizationURL1) csrfCookieValue2, upstreamStateParam2 := - requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParamsIDP1, upstreamIDPAuthorizationURL1) - callbackRequestParams1 := "?" + url.Values{ - "code": []string{"some-fake-code"}, - "state": []string{upstreamStateParam1}, - }.Encode() - callbackRequestParams2 := "?" + url.Values{ - "code": []string{"some-fake-code"}, - "state": []string{upstreamStateParam2}, - }.Encode() + csrfCookieValue3, upstreamStateParam3 := + requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParamsIDP2, upstreamIDPAuthorizationURL2) + csrfCookieValue4, upstreamStateParam4 := + requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParamsIDP2, upstreamIDPAuthorizationURL2) + + callbackRequestParams1 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam1}}.Encode() + callbackRequestParams2 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam2}}.Encode() + callbackRequestParams3 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam3}}.Encode() + callbackRequestParams4 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam4}}.Encode() downstreamAuthCode1 := requireCallbackRequestToBeHandled(issuer1, callbackRequestParams1, csrfCookieValue1) downstreamAuthCode2 := requireCallbackRequestToBeHandled(issuer2, callbackRequestParams2, csrfCookieValue2) @@ -381,12 +429,17 @@ func TestManager(t *testing.T) { downstreamAuthCode3 := requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams1, csrfCookieValue1) downstreamAuthCode4 := requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams2, csrfCookieValue2) + downstreamAuthCode5 := requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams3, csrfCookieValue3) + downstreamAuthCode6 := requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams4, csrfCookieValue4) + requireTokenRequestToBeHandled(issuer1, downstreamAuthCode1, issuer1JWKS, issuer1) requireTokenRequestToBeHandled(issuer2, downstreamAuthCode2, issuer2JWKS, issuer2) // Hostnames are case-insensitive, so test that we can handle that. requireTokenRequestToBeHandled(issuer1DifferentCaseHostname, downstreamAuthCode3, issuer1JWKS, issuer1) requireTokenRequestToBeHandled(issuer2DifferentCaseHostname, downstreamAuthCode4, issuer2JWKS, issuer2) + requireTokenRequestToBeHandled(issuer1DifferentCaseHostname, downstreamAuthCode5, issuer1JWKS, issuer1) + requireTokenRequestToBeHandled(issuer2DifferentCaseHostname, downstreamAuthCode6, issuer2JWKS, issuer2) } when("given some valid providers via SetFederationDomains()", func() {