use multiple IDPs in manager_test.go

This commit is contained in:
Ryan Richard 2023-07-18 16:22:21 -07:00
parent e42e3ca421
commit 64f41d0d0c

View File

@ -50,9 +50,14 @@ func TestManager(t *testing.T) {
issuer2 = "https://example.com/some/path/more/deeply/nested/path" // note that this is a sub-path of the other issuer url issuer2 = "https://example.com/some/path/more/deeply/nested/path" // note that this is a sub-path of the other issuer url
issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path"
issuer2KeyID = "issuer2-key" issuer2KeyID = "issuer2-key"
upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" upstreamIDPAuthorizationURL1 = "https://test-upstream.com/auth1"
upstreamIDPName = "test-idp" upstreamIDPAuthorizationURL2 = "https://test-upstream.com/auth2"
upstreamResourceUID = "test-resource-uid" upstreamIDPDisplayName1 = "test-idp-display-name-1"
upstreamIDPDisplayName2 = "test-idp-display-name-2"
upstreamIDPName1 = "test-idp-1"
upstreamIDPName2 = "test-idp-2"
upstreamResourceUID1 = "test-resource-uid-1"
upstreamResourceUID2 = "test-resource-uid-2"
upstreamIDPType = "oidc" upstreamIDPType = "oidc"
downstreamClientID = "pinniped-cli" downstreamClientID = "pinniped-cli"
downstreamRedirectURL = "http://127.0.0.1:12345/callback" downstreamRedirectURL = "http://127.0.0.1:12345/callback"
@ -82,7 +87,7 @@ func TestManager(t *testing.T) {
r.False(fallbackHandlerWasCalled) r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right discovery endpoint was called // Minimal check to ensure that the right discovery endpoint was called
r.Equal(http.StatusOK, recorder.Code) r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder)
responseBody, err := io.ReadAll(recorder.Body) responseBody, err := io.ReadAll(recorder.Body)
r.NoError(err) r.NoError(err)
parsedDiscoveryResult := discovery.Metadata{} parsedDiscoveryResult := discovery.Metadata{}
@ -92,7 +97,7 @@ func TestManager(t *testing.T) {
r.Equal(parsedDiscoveryResult.SupervisorDiscovery.PinnipedIDPsEndpoint, expectedIssuer+oidc.PinnipedIDPsPathV1Alpha1) r.Equal(parsedDiscoveryResult.SupervisorDiscovery.PinnipedIDPsEndpoint, expectedIssuer+oidc.PinnipedIDPsPathV1Alpha1)
} }
requirePinnipedIDPsDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedIDPName, expectedIDPType string, expectedFlows []string) { requirePinnipedIDPsDiscoveryRequestToBeHandled := func(requestIssuer, requestURLSuffix string, expectedIDPNames []string, expectedIDPTypes string, expectedFlows []string) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.PinnipedIDPsPathV1Alpha1+requestURLSuffix)) subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.PinnipedIDPsPathV1Alpha1+requestURLSuffix))
@ -102,12 +107,18 @@ func TestManager(t *testing.T) {
expectedFlowsJSON, err := json.Marshal(expectedFlows) expectedFlowsJSON, err := json.Marshal(expectedFlows)
require.NoError(t, err) require.NoError(t, err)
expectedIDPJSONList := []string{}
for i := range expectedIDPNames {
expectedIDPJSONList = append(expectedIDPJSONList, fmt.Sprintf(`{"name":"%s","type":"%s","flows":%s}`,
expectedIDPNames[i], expectedIDPTypes, expectedFlowsJSON))
}
// Minimal check to ensure that the right IDP discovery endpoint was called // Minimal check to ensure that the right IDP discovery endpoint was called
r.Equal(http.StatusOK, recorder.Code) r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder)
responseBody, err := io.ReadAll(recorder.Body) responseBody, err := io.ReadAll(recorder.Body)
r.NoError(err) r.NoError(err)
r.Equal( r.Equal(
fmt.Sprintf(`{"pinniped_identity_providers":[{"name":"%s","type":"%s","flows":%s}]}`+"\n", expectedIDPName, expectedIDPType, expectedFlowsJSON), fmt.Sprintf(`{"pinniped_identity_providers":[%s]}`+"\n", strings.Join(expectedIDPJSONList, ",")),
string(responseBody), string(responseBody),
) )
} }
@ -120,7 +131,7 @@ func TestManager(t *testing.T) {
r.False(fallbackHandlerWasCalled) r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right endpoint was called // Minimal check to ensure that the right endpoint was called
r.Equal(http.StatusSeeOther, recorder.Code) r.Equal(http.StatusSeeOther, recorder.Code, "unexpected response:", recorder)
actualLocation := recorder.Header().Get("Location") actualLocation := recorder.Header().Get("Location")
r.True( r.True(
strings.HasPrefix(actualLocation, expectedRedirectLocationPrefix), strings.HasPrefix(actualLocation, expectedRedirectLocationPrefix),
@ -159,7 +170,7 @@ func TestManager(t *testing.T) {
// Check just enough of the response to ensure that we wired up the callback endpoint correctly. // Check just enough of the response to ensure that we wired up the callback endpoint correctly.
// The endpoint's own unit tests cover everything else. // The endpoint's own unit tests cover everything else.
r.Equal(http.StatusSeeOther, recorder.Code) r.Equal(http.StatusSeeOther, recorder.Code, "unexpected response:", recorder)
actualLocation := recorder.Header().Get("Location") actualLocation := recorder.Header().Get("Location")
r.True( r.True(
strings.HasPrefix(actualLocation, downstreamRedirectURL), strings.HasPrefix(actualLocation, downstreamRedirectURL),
@ -171,7 +182,7 @@ func TestManager(t *testing.T) {
actualLocationQueryParams := parsedLocation.Query() actualLocationQueryParams := parsedLocation.Query()
r.Contains(actualLocationQueryParams, "code") r.Contains(actualLocationQueryParams, "code")
r.Equal("openid username groups", actualLocationQueryParams.Get("scope")) r.Equal("openid username groups", actualLocationQueryParams.Get("scope"))
r.Equal("some-state-value-with-enough-bytes-to-exceed-min-allowed", actualLocationQueryParams.Get("state")) r.Equal("some-state", actualLocationQueryParams.Get("state"))
// Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions.
r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3, r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+3,
@ -198,8 +209,8 @@ func TestManager(t *testing.T) {
r.False(fallbackHandlerWasCalled) r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right endpoint was called // Minimal check to ensure that the right endpoint was called
r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder)
var body map[string]interface{} var body map[string]interface{}
r.Equal(http.StatusOK, recorder.Code)
r.NoError(json.Unmarshal(recorder.Body.Bytes(), &body)) r.NoError(json.Unmarshal(recorder.Body.Bytes(), &body))
r.Contains(body, "id_token") r.Contains(body, "id_token")
r.Contains(body, "access_token") r.Contains(body, "access_token")
@ -228,7 +239,7 @@ func TestManager(t *testing.T) {
r.False(fallbackHandlerWasCalled) r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right JWKS endpoint was called // Minimal check to ensure that the right JWKS endpoint was called
r.Equal(http.StatusOK, recorder.Code) r.Equal(http.StatusOK, recorder.Code, "unexpected response:", recorder)
responseBody, err := io.ReadAll(recorder.Body) responseBody, err := io.ReadAll(recorder.Body)
r.NoError(err) r.NoError(err)
parsedJWKSResult := jose.JSONWebKeySet{} parsedJWKSResult := jose.JSONWebKeySet{}
@ -246,31 +257,53 @@ func TestManager(t *testing.T) {
} }
dynamicJWKSProvider = jwks.NewDynamicJWKSProvider() dynamicJWKSProvider = jwks.NewDynamicJWKSProvider()
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) parsedUpstreamIDPAuthorizationURL1, err := url.Parse(upstreamIDPAuthorizationURL1)
r.NoError(err)
parsedUpstreamIDPAuthorizationURL2, err := url.Parse(upstreamIDPAuthorizationURL2)
r.NoError(err) r.NoError(err)
federationDomainIDPs = []*federationdomainproviders.FederationDomainIdentityProvider{ federationDomainIDPs = []*federationdomainproviders.FederationDomainIdentityProvider{
{ {
DisplayName: upstreamIDPName, DisplayName: upstreamIDPDisplayName1,
UID: upstreamResourceUID, UID: upstreamResourceUID1,
Transforms: idtransform.NewTransformationPipeline(),
},
{
DisplayName: upstreamIDPDisplayName2,
UID: upstreamResourceUID2,
Transforms: idtransform.NewTransformationPipeline(), Transforms: idtransform.NewTransformationPipeline(),
}, },
} }
idpLister := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). idpLister := oidctestutil.NewUpstreamIDPListerBuilder().
WithName(upstreamIDPName). WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithClientID("test-client-id"). WithName(upstreamIDPName1).
WithResourceUID(upstreamResourceUID). WithClientID("test-client-id-1").
WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL). WithResourceUID(upstreamResourceUID1).
WithScopes([]string{"test-scope"}). WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL1).
WithIDTokenClaim("iss", "https://some-issuer.com"). WithScopes([]string{"test-scope"}).
WithIDTokenClaim("sub", "some-subject"). WithIDTokenClaim("iss", "https://some-issuer.com").
WithIDTokenClaim("username", "test-username"). WithIDTokenClaim("sub", "some-subject").
WithIDTokenClaim("groups", "test-group1"). WithIDTokenClaim("username", "test-username").
WithRefreshToken("some-opaque-token"). WithIDTokenClaim("groups", "test-group1").
WithoutAccessToken(). WithRefreshToken("some-opaque-token").
Build(), WithoutAccessToken().
).BuildDynamicUpstreamIDPProvider() Build(),
).
WithOIDC(oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName(upstreamIDPName2).
WithClientID("test-client-id-2").
WithResourceUID(upstreamResourceUID2).
WithAuthorizationURL(*parsedUpstreamIDPAuthorizationURL2).
WithScopes([]string{"test-scope"}).
WithIDTokenClaim("iss", "https://some-issuer.com").
WithIDTokenClaim("sub", "some-subject").
WithIDTokenClaim("username", "test-username").
WithIDTokenClaim("groups", "test-group1").
WithRefreshToken("some-opaque-token").
WithoutAccessToken().
Build(),
).BuildDynamicUpstreamIDPProvider()
kubeClient = fake.NewSimpleClientset() kubeClient = fake.NewSimpleClientset()
secretsClient := kubeClient.CoreV1().Secrets("some-namespace") secretsClient := kubeClient.CoreV1().Secrets("some-namespace")
@ -326,14 +359,14 @@ func TestManager(t *testing.T) {
requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2) requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2)
requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2) requireDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2)
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "?some=query", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2, "?some=query", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1DifferentCaseHostname, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer1DifferentCaseHostname, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", upstreamIDPName, upstreamIDPType, upstreamIDPFlows) requirePinnipedIDPsDiscoveryRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", []string{upstreamIDPDisplayName1, upstreamIDPDisplayName2}, upstreamIDPType, upstreamIDPFlows)
issuer1JWKS := requireJWKSRequestToBeHandled(issuer1, "", issuer1KeyID) issuer1JWKS := requireJWKSRequestToBeHandled(issuer1, "", issuer1KeyID)
issuer2JWKS := requireJWKSRequestToBeHandled(issuer2, "", issuer2KeyID) issuer2JWKS := requireJWKSRequestToBeHandled(issuer2, "", issuer2KeyID)
@ -344,35 +377,50 @@ func TestManager(t *testing.T) {
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID)
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID)
authRequestParams := "?" + url.Values{ authRequestParamsIDP1 := "?" + url.Values{
"pinniped_idp_name": []string{upstreamIDPName}, "pinniped_idp_name": []string{upstreamIDPDisplayName1},
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"openid profile email username groups"}, "scope": []string{"openid profile email username groups"},
"client_id": []string{downstreamClientID}, "client_id": []string{downstreamClientID},
"state": []string{"some-state-value-with-enough-bytes-to-exceed-min-allowed"}, "state": []string{"some-state"},
"nonce": []string{"some-nonce-value-with-enough-bytes-to-exceed-min-allowed"}, "nonce": []string{"some-nonce-value-with-enough-bytes-to-exceed-min-allowed"},
"code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)}, "code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURL}, "redirect_uri": []string{downstreamRedirectURL},
}.Encode() }.Encode()
requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) authRequestParamsIDP2 := "?" + url.Values{
requireAuthorizationRequestToBeHandled(issuer2, authRequestParams, upstreamIDPAuthorizationURL) "pinniped_idp_name": []string{upstreamIDPDisplayName2},
"response_type": []string{"code"},
"scope": []string{"openid profile email username groups"},
"client_id": []string{downstreamClientID},
"state": []string{"some-state"},
"nonce": []string{"some-nonce-value-with-enough-bytes-to-exceed-min-allowed"},
"code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)},
"code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURL},
}.Encode()
requireAuthorizationRequestToBeHandled(issuer1, authRequestParamsIDP1, upstreamIDPAuthorizationURL1)
requireAuthorizationRequestToBeHandled(issuer2, authRequestParamsIDP1, upstreamIDPAuthorizationURL1)
requireAuthorizationRequestToBeHandled(issuer1, authRequestParamsIDP2, upstreamIDPAuthorizationURL2)
requireAuthorizationRequestToBeHandled(issuer2, authRequestParamsIDP2, upstreamIDPAuthorizationURL2)
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
csrfCookieValue1, upstreamStateParam1 := csrfCookieValue1, upstreamStateParam1 :=
requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParamsIDP1, upstreamIDPAuthorizationURL1)
csrfCookieValue2, upstreamStateParam2 := csrfCookieValue2, upstreamStateParam2 :=
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParamsIDP1, upstreamIDPAuthorizationURL1)
callbackRequestParams1 := "?" + url.Values{ csrfCookieValue3, upstreamStateParam3 :=
"code": []string{"some-fake-code"}, requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParamsIDP2, upstreamIDPAuthorizationURL2)
"state": []string{upstreamStateParam1}, csrfCookieValue4, upstreamStateParam4 :=
}.Encode() requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParamsIDP2, upstreamIDPAuthorizationURL2)
callbackRequestParams2 := "?" + url.Values{
"code": []string{"some-fake-code"}, callbackRequestParams1 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam1}}.Encode()
"state": []string{upstreamStateParam2}, callbackRequestParams2 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam2}}.Encode()
}.Encode() callbackRequestParams3 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam3}}.Encode()
callbackRequestParams4 := "?" + url.Values{"code": []string{"some-fake-code"}, "state": []string{upstreamStateParam4}}.Encode()
downstreamAuthCode1 := requireCallbackRequestToBeHandled(issuer1, callbackRequestParams1, csrfCookieValue1) downstreamAuthCode1 := requireCallbackRequestToBeHandled(issuer1, callbackRequestParams1, csrfCookieValue1)
downstreamAuthCode2 := requireCallbackRequestToBeHandled(issuer2, callbackRequestParams2, csrfCookieValue2) downstreamAuthCode2 := requireCallbackRequestToBeHandled(issuer2, callbackRequestParams2, csrfCookieValue2)
@ -381,12 +429,17 @@ func TestManager(t *testing.T) {
downstreamAuthCode3 := requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams1, csrfCookieValue1) downstreamAuthCode3 := requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams1, csrfCookieValue1)
downstreamAuthCode4 := requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams2, csrfCookieValue2) downstreamAuthCode4 := requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams2, csrfCookieValue2)
downstreamAuthCode5 := requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams3, csrfCookieValue3)
downstreamAuthCode6 := requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams4, csrfCookieValue4)
requireTokenRequestToBeHandled(issuer1, downstreamAuthCode1, issuer1JWKS, issuer1) requireTokenRequestToBeHandled(issuer1, downstreamAuthCode1, issuer1JWKS, issuer1)
requireTokenRequestToBeHandled(issuer2, downstreamAuthCode2, issuer2JWKS, issuer2) requireTokenRequestToBeHandled(issuer2, downstreamAuthCode2, issuer2JWKS, issuer2)
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
requireTokenRequestToBeHandled(issuer1DifferentCaseHostname, downstreamAuthCode3, issuer1JWKS, issuer1) requireTokenRequestToBeHandled(issuer1DifferentCaseHostname, downstreamAuthCode3, issuer1JWKS, issuer1)
requireTokenRequestToBeHandled(issuer2DifferentCaseHostname, downstreamAuthCode4, issuer2JWKS, issuer2) requireTokenRequestToBeHandled(issuer2DifferentCaseHostname, downstreamAuthCode4, issuer2JWKS, issuer2)
requireTokenRequestToBeHandled(issuer1DifferentCaseHostname, downstreamAuthCode5, issuer1JWKS, issuer1)
requireTokenRequestToBeHandled(issuer2DifferentCaseHostname, downstreamAuthCode6, issuer2JWKS, issuer2)
} }
when("given some valid providers via SetFederationDomains()", func() { when("given some valid providers via SetFederationDomains()", func() {