Merge pull request #1389 from vmware-tanzu/error_assertions

Accept both old and new cert error strings on MacOS in test assertions
This commit is contained in:
Joshua Casey 2023-01-24 15:06:40 -06:00 committed by GitHub
commit d2afdfaf9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 278 additions and 244 deletions

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package cmd package cmd
@ -107,7 +107,7 @@ func TestGetKubeconfig(t *testing.T) {
wantLogs func(string, string) []string wantLogs func(string, string) []string
wantError bool wantError bool
wantStdout func(string, string) string wantStdout func(string, string) string
wantStderr func(string, string) string wantStderr func(string, string) testutil.RequireErrorStringFunc
wantOptionsCount int wantOptionsCount int
wantAPIGroupSuffix string wantAPIGroupSuffix string
}{ }{
@ -164,8 +164,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: invalid argument "./does/not/exist" for "--oidc-ca-bundle" flag: could not read CA bundle path: open ./does/not/exist: no such file or directory` + "\n" return testutil.WantExactErrorString(`Error: invalid argument "./does/not/exist" for "--oidc-ca-bundle" flag: could not read CA bundle path: open ./does/not/exist: no such file or directory` + "\n")
}, },
}, },
{ {
@ -177,8 +177,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: invalid argument "./does/not/exist" for "--concierge-ca-bundle" flag: could not read CA bundle path: open ./does/not/exist: no such file or directory` + "\n" return testutil.WantExactErrorString(`Error: invalid argument "./does/not/exist" for "--concierge-ca-bundle" flag: could not read CA bundle path: open ./does/not/exist: no such file or directory` + "\n")
}, },
}, },
{ {
@ -189,8 +189,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not load --kubeconfig: stat ./does/not/exist: no such file or directory` + "\n" return testutil.WantExactErrorString(`Error: could not load --kubeconfig: stat ./does/not/exist: no such file or directory` + "\n")
}, },
}, },
{ {
@ -202,8 +202,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not load --kubeconfig/--kubeconfig-context: no such context "invalid"` + "\n" return testutil.WantExactErrorString(`Error: could not load --kubeconfig/--kubeconfig-context: no such context "invalid"` + "\n")
}, },
}, },
{ {
@ -215,8 +215,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not load --kubeconfig/--kubeconfig-context: no such cluster "invalid-cluster"` + "\n" return testutil.WantExactErrorString(`Error: could not load --kubeconfig/--kubeconfig-context: no such cluster "invalid-cluster"` + "\n")
}, },
}, },
{ {
@ -228,8 +228,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not load --kubeconfig/--kubeconfig-context: no such user "invalid-user"` + "\n" return testutil.WantExactErrorString(`Error: could not load --kubeconfig/--kubeconfig-context: no such user "invalid-user"` + "\n")
}, },
}, },
{ {
@ -241,8 +241,8 @@ func TestGetKubeconfig(t *testing.T) {
}, },
getClientsetErr: fmt.Errorf("some kube error"), getClientsetErr: fmt.Errorf("some kube error"),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not configure Kubernetes client: some kube error` + "\n" return testutil.WantExactErrorString(`Error: could not configure Kubernetes client: some kube error` + "\n")
}, },
}, },
{ {
@ -253,8 +253,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no CredentialIssuers were found` + "\n" return testutil.WantExactErrorString(`Error: no CredentialIssuers were found` + "\n")
}, },
}, },
{ {
@ -271,8 +271,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: credentialissuers.config.concierge.pinniped.dev "does-not-exist" not found` + "\n" return testutil.WantExactErrorString(`Error: credentialissuers.config.concierge.pinniped.dev "does-not-exist" not found` + "\n")
}, },
}, },
{ {
@ -295,8 +295,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: webhookauthenticators.authentication.concierge.pinniped.dev "test-authenticator" not found` + "\n" return testutil.WantExactErrorString(`Error: webhookauthenticators.authentication.concierge.pinniped.dev "test-authenticator" not found` + "\n")
}, },
}, },
{ {
@ -319,8 +319,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: jwtauthenticators.authentication.concierge.pinniped.dev "test-authenticator" not found` + "\n" return testutil.WantExactErrorString(`Error: jwtauthenticators.authentication.concierge.pinniped.dev "test-authenticator" not found` + "\n")
}, },
}, },
{ {
@ -343,8 +343,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: invalid authenticator type "invalid", supported values are "webhook" and "jwt"` + "\n" return testutil.WantExactErrorString(`Error: invalid authenticator type "invalid", supported values are "webhook" and "jwt"` + "\n")
}, },
}, },
{ {
@ -374,8 +374,8 @@ func TestGetKubeconfig(t *testing.T) {
}, },
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: failed to list JWTAuthenticator objects for autodiscovery: some list error` + "\n" return testutil.WantExactErrorString(`Error: failed to list JWTAuthenticator objects for autodiscovery: some list error` + "\n")
}, },
}, },
{ {
@ -405,8 +405,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: failed to list WebhookAuthenticator objects for autodiscovery: some list error` + "\n" return testutil.WantExactErrorString(`Error: failed to list WebhookAuthenticator objects for autodiscovery: some list error` + "\n")
}, },
}, },
{ {
@ -427,8 +427,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no authenticators were found` + "\n" return testutil.WantExactErrorString(`Error: no authenticators were found` + "\n")
}, },
}, },
{ {
@ -457,8 +457,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: multiple authenticators were found, so the --concierge-authenticator-type/--concierge-authenticator-name flags must be specified` + "\n" return testutil.WantExactErrorString(`Error: multiple authenticators were found, so the --concierge-authenticator-type/--concierge-authenticator-name flags must be specified` + "\n")
}, },
}, },
{ {
@ -491,8 +491,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not autodiscover --concierge-mode` + "\n" return testutil.WantExactErrorString(`Error: could not autodiscover --concierge-mode` + "\n")
}, },
}, },
{ {
@ -553,8 +553,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: autodiscovered Concierge CA bundle is invalid: illegal base64 data at input byte 7` + "\n" return testutil.WantExactErrorString(`Error: autodiscovered Concierge CA bundle is invalid: illegal base64 data at input byte 7` + "\n")
}, },
}, },
{ {
@ -580,8 +580,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not autodiscover --oidc-issuer and none was provided` + "\n" return testutil.WantExactErrorString(`Error: could not autodiscover --oidc-issuer and none was provided` + "\n")
}, },
}, },
{ {
@ -635,8 +635,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: tried to autodiscover --oidc-ca-bundle, but JWTAuthenticator test-authenticator has invalid spec.tls.certificateAuthorityData: illegal base64 data at input byte 7` + "\n" return testutil.WantExactErrorString(`Error: tried to autodiscover --oidc-ca-bundle, but JWTAuthenticator test-authenticator has invalid spec.tls.certificateAuthorityData: illegal base64 data at input byte 7` + "\n")
}, },
}, },
{ {
@ -675,8 +675,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: request audience is not allowed to include the substring '.pinniped.dev': some-test-audience.pinniped.dev-invalid-substring` + "\n" return testutil.WantExactErrorString(`Error: request audience is not allowed to include the substring '.pinniped.dev': some-test-audience.pinniped.dev-invalid-substring` + "\n")
}, },
}, },
{ {
@ -706,8 +706,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: request audience is not allowed to include the substring '.pinniped.dev': some-test-audience.pinniped.dev-invalid-substring` + "\n" return testutil.WantExactErrorString(`Error: request audience is not allowed to include the substring '.pinniped.dev': some-test-audience.pinniped.dev-invalid-substring` + "\n")
}, },
}, },
{ {
@ -738,8 +738,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: could not determine the Pinniped executable path: some OS error` + "\n" return testutil.WantExactErrorString(`Error: could not determine the Pinniped executable path: some OS error` + "\n")
}, },
}, },
{ {
@ -767,8 +767,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: only one of --static-token and --static-token-env can be specified` + "\n" return testutil.WantExactErrorString(`Error: only one of --static-token and --static-token-env can be specified` + "\n")
}, },
}, },
{ {
@ -779,8 +779,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: invalid API group suffix: a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character (e.g. 'example.com', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*')` + "\n" return testutil.WantExactErrorString(`Error: invalid API group suffix: a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character (e.g. 'example.com', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*')` + "\n")
}, },
}, },
{ {
@ -811,8 +811,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return "Error: while fetching OIDC discovery data from issuer: 400 Bad Request: {}\n" return testutil.WantExactErrorString("Error: while fetching OIDC discovery data from issuer: 400 Bad Request: {}\n")
}, },
}, },
{ {
@ -847,8 +847,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return fmt.Sprintf( return testutil.WantSprintfErrorString(
"Error: while fetching OIDC discovery data from issuer: oidc: issuer did not match the issuer returned by provider, expected \"%s\" got \"https://wrong-issuer.com\"\n", "Error: while fetching OIDC discovery data from issuer: oidc: issuer did not match the issuer returned by provider, expected \"%s\" got \"https://wrong-issuer.com\"\n",
issuerURL) issuerURL)
}, },
@ -882,8 +882,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return "Error: unable to fetch IDP discovery data from issuer: unexpected http response status: 400 Bad Request\n" return testutil.WantExactErrorString("Error: unable to fetch IDP discovery data from issuer: unexpected http response status: 400 Bad Request\n")
}, },
}, },
{ {
@ -920,10 +920,10 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: multiple Supervisor upstream identity providers were found, ` + return testutil.WantExactErrorString(`Error: multiple Supervisor upstream identity providers were found, ` +
`so the --upstream-identity-provider-name/--upstream-identity-provider-type flags must be specified. ` + `so the --upstream-identity-provider-name/--upstream-identity-provider-type flags must be specified. ` +
`Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-oidc-idp","type":"oidc"}]` + "\n" `Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -956,8 +956,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return "Error: while fetching OIDC discovery data from issuer: oidc: failed to decode provider discovery object: got Content-Type = application/json, but could not unmarshal as JSON: invalid character 'h' in literal true (expecting 'r')\n" return testutil.WantExactErrorString("Error: while fetching OIDC discovery data from issuer: oidc: failed to decode provider discovery object: got Content-Type = application/json, but could not unmarshal as JSON: invalid character 'h' in literal true (expecting 'r')\n")
}, },
}, },
{ {
@ -989,8 +989,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return "Error: unable to fetch IDP discovery data from issuer: could not parse response JSON: invalid character 'h' in literal true (expecting 'r')\n" return testutil.WantExactErrorString("Error: unable to fetch IDP discovery data from issuer: could not parse response JSON: invalid character 'h' in literal true (expecting 'r')\n")
}, },
}, },
{ {
@ -1025,8 +1025,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return fmt.Sprintf("Error: while fetching OIDC discovery data from issuer: Get \"%s/.well-known/openid-configuration\": %s\n", issuerURL, testutil.X509UntrustedCertError("Acme Co")) return testutil.WantX509UntrustedCertErrorString(fmt.Sprintf("Error: while fetching OIDC discovery data from issuer: Get \"%s/.well-known/openid-configuration\": %%s\n", issuerURL), "Acme Co")
}, },
}, },
{ {
@ -1063,8 +1063,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: while fetching OIDC discovery data from issuer: parse "https%://bad-issuer-url/.well-known/openid-configuration": first path segment in URL cannot contain colon` + "\n" return testutil.WantExactErrorString(`Error: while fetching OIDC discovery data from issuer: parse "https%://bad-issuer-url/.well-known/openid-configuration": first path segment in URL cannot contain colon` + "\n")
}, },
}, },
{ {
@ -1102,8 +1102,8 @@ func TestGetKubeconfig(t *testing.T) {
} }
}, },
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: while forming request to IDP discovery URL: parse "https%://illegal_url": first path segment in URL cannot contain colon` + "\n" return testutil.WantExactErrorString(`Error: while forming request to IDP discovery URL: parse "https%://illegal_url": first path segment in URL cannot contain colon` + "\n")
}, },
}, },
{ {
@ -1128,9 +1128,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no Supervisor upstream identity providers with name "does-not-exist-idp" of type "ldap" were found.` + return testutil.WantExactErrorString(`Error: no Supervisor upstream identity providers with name "does-not-exist-idp" of type "ldap" were found.` +
` Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-other-ldap-idp","type":"ldap"}]` + "\n" ` Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-other-ldap-idp","type":"ldap"}]` + "\n")
}, },
}, },
{ {
@ -1156,10 +1156,10 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: multiple Supervisor upstream identity providers of type "ldap" were found,` + return testutil.WantExactErrorString(`Error: multiple Supervisor upstream identity providers of type "ldap" were found,` +
` so the --upstream-identity-provider-name flag must be specified.` + ` so the --upstream-identity-provider-name flag must be specified.` +
` Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-other-ldap-idp","type":"ldap"},{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"some-ldap-idp","type":"ldap"},{"name":"some-other-ldap-idp","type":"ldap"},{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1184,10 +1184,10 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: multiple Supervisor upstream identity providers with name "my-idp" were found,` + return testutil.WantExactErrorString(`Error: multiple Supervisor upstream identity providers with name "my-idp" were found,` +
` so the --upstream-identity-provider-type flag must be specified.` + ` so the --upstream-identity-provider-type flag must be specified.` +
` Found these upstreams: [{"name":"my-idp","type":"ldap"},{"name":"my-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"my-idp","type":"ldap"},{"name":"my-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1211,9 +1211,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no Supervisor upstream identity providers of type "ldap" were found.` + return testutil.WantExactErrorString(`Error: no Supervisor upstream identity providers of type "ldap" were found.` +
` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1236,9 +1236,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no Supervisor upstream identity providers of type "ldap" were found.` + return testutil.WantExactErrorString(`Error: no Supervisor upstream identity providers of type "ldap" were found.` +
` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1262,9 +1262,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no Supervisor upstream identity providers with name "my-nonexistent-idp" were found.` + return testutil.WantExactErrorString(`Error: no Supervisor upstream identity providers with name "my-nonexistent-idp" were found.` +
` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"},{"name":"some-other-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1287,9 +1287,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no Supervisor upstream identity providers with name "my-nonexistent-idp" were found.` + return testutil.WantExactErrorString(`Error: no Supervisor upstream identity providers with name "my-nonexistent-idp" were found.` +
` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"}]` + "\n" ` Found these upstreams: [{"name":"some-oidc-idp","type":"oidc"}]` + "\n")
}, },
}, },
{ {
@ -1312,9 +1312,9 @@ func TestGetKubeconfig(t *testing.T) {
] ]
}`), }`),
wantError: true, wantError: true,
wantStderr: func(issuerCABundle string, issuerURL string) string { wantStderr: func(issuerCABundle string, issuerURL string) testutil.RequireErrorStringFunc {
return `Error: no client flow "my-nonexistent-flow" for Supervisor upstream identity provider "some-oidc-idp" of type "oidc" were found.` + return testutil.WantExactErrorString(`Error: no client flow "my-nonexistent-flow" for Supervisor upstream identity provider "some-oidc-idp" of type "oidc" were found.` +
` Found these flows: [non-matching-flow-1 non-matching-flow-2]` + "\n" ` Found these flows: [non-matching-flow-1 non-matching-flow-2]` + "\n")
}, },
}, },
{ {
@ -3015,11 +3015,12 @@ func TestGetKubeconfig(t *testing.T) {
} }
require.Equal(t, expectedStdout, stdout.String(), "unexpected stdout") require.Equal(t, expectedStdout, stdout.String(), "unexpected stdout")
expectedStderr := "" actualStderr := stderr.String()
if tt.wantStderr != nil { if tt.wantStderr != nil {
expectedStderr = tt.wantStderr(issuerCABundle, issuerEndpoint) testutil.RequireErrorString(t, actualStderr, tt.wantStderr(issuerCABundle, issuerEndpoint))
} else {
require.Empty(t, actualStderr, "unexpected stderr")
} }
require.Equal(t, expectedStderr, stderr.String(), "unexpected stderr")
}) })
} }
} }

1
go.mod
View File

@ -124,7 +124,6 @@ require (
github.com/spf13/cast v1.5.0 // indirect github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect
github.com/stoewer/go-strcase v1.2.0 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/subosito/gotenv v1.4.0 // indirect github.com/subosito/gotenv v1.4.0 // indirect
github.com/tdewolff/parse/v2 v2.6.4 // indirect github.com/tdewolff/parse/v2 v2.6.4 // indirect
go.etcd.io/etcd/api/v3 v3.5.5 // indirect go.etcd.io/etcd/api/v3 v3.5.5 // indirect

1
go.sum
View File

@ -537,7 +537,6 @@ github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package jwtcachefiller package jwtcachefiller
@ -188,7 +188,7 @@ func TestController(t *testing.T) {
syncKey controllerlib.Key syncKey controllerlib.Key
jwtAuthenticators []runtime.Object jwtAuthenticators []runtime.Object
wantClose bool wantClose bool
wantErr string wantErr testutil.RequireErrorStringFunc
wantLogs []string wantLogs []string
wantCacheEntries int wantCacheEntries int
wantUsernameClaim string wantUsernameClaim string
@ -350,7 +350,7 @@ func TestController(t *testing.T) {
Spec: *missingTLSJWTAuthenticatorSpec, Spec: *missingTLSJWTAuthenticatorSpec,
}, },
}, },
wantErr: `failed to build jwt authenticator: could not initialize provider: Get "` + goodIssuer + `/.well-known/openid-configuration": ` + testutil.X509UntrustedCertError("Acme Co"), wantErr: testutil.WantX509UntrustedCertErrorString(`failed to build jwt authenticator: could not initialize provider: Get "`+goodIssuer+`/.well-known/openid-configuration": %s`, "Acme Co"),
}, },
{ {
name: "invalid jwt authenticator CA", name: "invalid jwt authenticator CA",
@ -363,7 +363,7 @@ func TestController(t *testing.T) {
Spec: *invalidTLSJWTAuthenticatorSpec, Spec: *invalidTLSJWTAuthenticatorSpec,
}, },
}, },
wantErr: "failed to build jwt authenticator: invalid TLS configuration: illegal base64 data at input byte 7", wantErr: testutil.WantExactErrorString("failed to build jwt authenticator: invalid TLS configuration: illegal base64 data at input byte 7"),
}, },
} }
@ -391,8 +391,8 @@ func TestController(t *testing.T) {
syncCtx := controllerlib.Context{Context: ctx, Key: tt.syncKey} syncCtx := controllerlib.Context{Context: ctx, Key: tt.syncKey}
if err := controllerlib.TestSync(t, controller, syncCtx); tt.wantErr != "" { if err := controllerlib.TestSync(t, controller, syncCtx); tt.wantErr != nil {
require.EqualError(t, err, tt.wantErr) testutil.RequireErrorStringFromErr(t, err, tt.wantErr)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
} }
@ -490,9 +490,8 @@ func TestController(t *testing.T) {
rsp, authenticated, err = cachedAuthenticator.AuthenticateToken(context.Background(), jwt) rsp, authenticated, err = cachedAuthenticator.AuthenticateToken(context.Background(), jwt)
return !isNotInitialized(err), nil return !isNotInitialized(err), nil
}) })
if test.wantErrorRegexp != "" { if test.wantErr != nil {
require.Error(t, err) testutil.RequireErrorStringFromErr(t, err, test.wantErr)
require.Regexp(t, test.wantErrorRegexp, err.Error())
} else { } else {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, test.wantResponse, rsp) require.Equal(t, test.wantResponse, rsp)
@ -528,7 +527,7 @@ func testTableForAuthenticateTokenTests(
jwtSignature func(key *interface{}, algo *jose.SignatureAlgorithm, kid *string) jwtSignature func(key *interface{}, algo *jose.SignatureAlgorithm, kid *string)
wantResponse *authenticator.Response wantResponse *authenticator.Response
wantAuthenticated bool wantAuthenticated bool
wantErrorRegexp string wantErr testutil.RequireErrorStringFunc
distributedGroupsClaimURL string distributedGroupsClaimURL string
} { } {
tests := []struct { tests := []struct {
@ -537,7 +536,7 @@ func testTableForAuthenticateTokenTests(
jwtSignature func(key *interface{}, algo *jose.SignatureAlgorithm, kid *string) jwtSignature func(key *interface{}, algo *jose.SignatureAlgorithm, kid *string)
wantResponse *authenticator.Response wantResponse *authenticator.Response
wantAuthenticated bool wantAuthenticated bool
wantErrorRegexp string wantErr testutil.RequireErrorStringFunc
distributedGroupsClaimURL string distributedGroupsClaimURL string
}{ }{
{ {
@ -594,14 +593,14 @@ func testTableForAuthenticateTokenTests(
jwtClaims: func(claims *jwt.Claims, groups *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, groups *interface{}, username *string) {
}, },
distributedGroupsClaimURL: issuer + "/not_found_claim_source", distributedGroupsClaimURL: issuer + "/not_found_claim_source",
wantErrorRegexp: `oidc: could not expand distributed claims: while getting distributed claim "` + expectedGroupsClaim + `": error while getting distributed claim JWT: 404 Not Found`, wantErr: testutil.WantMatchingErrorString(`oidc: could not expand distributed claims: while getting distributed claim "` + expectedGroupsClaim + `": error while getting distributed claim JWT: 404 Not Found`),
}, },
{ {
name: "distributed groups doesn't return the right claim", name: "distributed groups doesn't return the right claim",
jwtClaims: func(claims *jwt.Claims, groups *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, groups *interface{}, username *string) {
}, },
distributedGroupsClaimURL: issuer + "/wrong_claim_source", distributedGroupsClaimURL: issuer + "/wrong_claim_source",
wantErrorRegexp: `oidc: could not expand distributed claims: jwt returned by distributed claim endpoint "` + issuer + `/wrong_claim_source" did not contain claim: `, wantErr: testutil.WantMatchingErrorString(`oidc: could not expand distributed claims: jwt returned by distributed claim endpoint "` + issuer + `/wrong_claim_source" did not contain claim: `),
}, },
{ {
name: "good token with groups as string", name: "good token with groups as string",
@ -633,7 +632,7 @@ func testTableForAuthenticateTokenTests(
jwtClaims: func(_ *jwt.Claims, groups *interface{}, username *string) { jwtClaims: func(_ *jwt.Claims, groups *interface{}, username *string) {
*groups = map[string]string{"not an array": "or a string"} *groups = map[string]string{"not an array": "or a string"}
}, },
wantErrorRegexp: "oidc: parse groups claim \"" + expectedGroupsClaim + "\": json: cannot unmarshal object into Go value of type string", wantErr: testutil.WantMatchingErrorString("oidc: parse groups claim \"" + expectedGroupsClaim + "\": json: cannot unmarshal object into Go value of type string"),
}, },
{ {
name: "bad token with wrong issuer", name: "bad token with wrong issuer",
@ -648,42 +647,42 @@ func testTableForAuthenticateTokenTests(
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
claims.Audience = nil claims.Audience = nil
}, },
wantErrorRegexp: `oidc: verify token: oidc: expected audience "some-audience" got \[\]`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: expected audience "some-audience" got \[\]`),
}, },
{ {
name: "bad token with wrong audience", name: "bad token with wrong audience",
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
claims.Audience = []string{"wrong-audience"} claims.Audience = []string{"wrong-audience"}
}, },
wantErrorRegexp: `oidc: verify token: oidc: expected audience "some-audience" got \["wrong-audience"\]`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: expected audience "some-audience" got \["wrong-audience"\]`),
}, },
{ {
name: "bad token with nbf in the future", name: "bad token with nbf in the future",
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
claims.NotBefore = jwt.NewNumericDate(time.Date(3020, 2, 3, 4, 5, 6, 7, time.UTC)) claims.NotBefore = jwt.NewNumericDate(time.Date(3020, 2, 3, 4, 5, 6, 7, time.UTC))
}, },
wantErrorRegexp: `oidc: verify token: oidc: current time .* before the nbf \(not before\) time: 3020-.*`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: current time .* before the nbf \(not before\) time: 3020-.*`),
}, },
{ {
name: "bad token with exp in past", name: "bad token with exp in past",
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
claims.Expiry = jwt.NewNumericDate(time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC)) claims.Expiry = jwt.NewNumericDate(time.Date(1, 2, 3, 4, 5, 6, 7, time.UTC))
}, },
wantErrorRegexp: `oidc: verify token: oidc: token is expired \(Token Expiry: .+`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: token is expired \(Token Expiry: .+`),
}, },
{ {
name: "bad token without exp", name: "bad token without exp",
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
claims.Expiry = nil claims.Expiry = nil
}, },
wantErrorRegexp: `oidc: verify token: oidc: token is expired \(Token Expiry: .+`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: token is expired \(Token Expiry: .+`),
}, },
{ {
name: "token does not have username claim", name: "token does not have username claim",
jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) { jwtClaims: func(claims *jwt.Claims, _ *interface{}, username *string) {
*username = "" *username = ""
}, },
wantErrorRegexp: `oidc: parse username claims "` + expectedUsernameClaim + `": claim not present`, wantErr: testutil.WantMatchingErrorString(`oidc: parse username claims "` + expectedUsernameClaim + `": claim not present`),
}, },
{ {
name: "signing key is wrong", name: "signing key is wrong",
@ -693,7 +692,7 @@ func testTableForAuthenticateTokenTests(
require.NoError(t, err) require.NoError(t, err)
*algo = jose.ES256 *algo = jose.ES256
}, },
wantErrorRegexp: `oidc: verify token: failed to verify signature: failed to verify id token signature`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: failed to verify signature: failed to verify id token signature`),
}, },
{ {
name: "signing algo is unsupported", name: "signing algo is unsupported",
@ -703,7 +702,7 @@ func testTableForAuthenticateTokenTests(
require.NoError(t, err) require.NoError(t, err)
*algo = jose.ES384 *algo = jose.ES384
}, },
wantErrorRegexp: `oidc: verify token: oidc: id token signed with unsupported algorithm, expected \["RS256" "ES256"\] got "ES384"`, wantErr: testutil.WantMatchingErrorString(`oidc: verify token: oidc: id token signed with unsupported algorithm, expected \["RS256" "ES256"\] got "ES384"`),
}, },
} }

View File

@ -1,10 +1,11 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package testutil package testutil
import ( import (
"context" "context"
"fmt"
"mime" "mime"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -125,3 +126,63 @@ func requireSecurityHeaders(t *testing.T, response *httptest.ResponseRecorder) {
// This check is more relaxed since Fosite can override the base header we set. // This check is more relaxed since Fosite can override the base header we set.
require.Contains(t, response.Header().Get("Cache-Control"), "no-store") require.Contains(t, response.Header().Get("Cache-Control"), "no-store")
} }
type RequireErrorStringFunc func(t *testing.T, actualErrorStr string)
// RequireErrorStringFromErr can be used to make assertions about errors in tests.
func RequireErrorStringFromErr(t *testing.T, actualError error, requireFunc RequireErrorStringFunc) {
require.Error(t, actualError)
requireFunc(t, actualError.Error())
}
// RequireErrorString can be used to make assertions about error strings in tests.
func RequireErrorString(t *testing.T, actualErrorStr string, requireFunc RequireErrorStringFunc) {
requireFunc(t, actualErrorStr)
}
// WantExactErrorString can be used to set up an expected value for an error string in a test table.
// Use when you want to express that the expected string must be an exact match.
func WantExactErrorString(wantErrStr string) RequireErrorStringFunc {
return func(t *testing.T, actualErrorStr string) {
require.Equal(t, wantErrStr, actualErrorStr)
}
}
// WantSprintfErrorString can be used to set up an expected value for an error string in a test table.
// Use when you want to express that an expected string built using fmt.Sprintf semantics must be an exact match.
func WantSprintfErrorString(wantErrSprintfSpecifier string, a ...interface{}) RequireErrorStringFunc {
wantErrStr := fmt.Sprintf(wantErrSprintfSpecifier, a...)
return func(t *testing.T, actualErrorStr string) {
require.Equal(t, wantErrStr, actualErrorStr)
}
}
// WantMatchingErrorString can be used to set up an expected value for an error string in a test table.
// Use when you want to express that the expected regexp must be a match.
func WantMatchingErrorString(wantErrRegexp string) RequireErrorStringFunc {
return func(t *testing.T, actualErrorStr string) {
require.Regexp(t, wantErrRegexp, actualErrorStr)
}
}
// WantX509UntrustedCertErrorString can be used to set up an expected value for an error string in a test table.
// expectedErrorFormatString must contain exactly one formatting verb, which should usually be %s, which will
// be replaced by the platform-specific X509 untrusted certs error string and then compared against expectedCommonName.
func WantX509UntrustedCertErrorString(expectedErrorFormatSpecifier string, expectedCommonName string) RequireErrorStringFunc {
// Starting in Go 1.18.1, and until it was fixed in Go 1.19.5, Go on MacOS had an incorrect error string.
// We don't care which error string was returned, as long as it is either the normal error string from
// the Go x509 library, or the error string that was accidentally returned from the Go x509 library in
// those versions of Go on MacOS which had the bug.
return func(t *testing.T, actualErrorStr string) {
// This is the MacOS error string starting in Go 1.18.1, and until it was fixed in Go 1.19.5.
macOSErr := fmt.Sprintf(`x509: “%s” certificate is not trusted`, expectedCommonName)
// This is the normal Go x509 library error string.
standardErr := `x509: certificate signed by unknown authority`
allowedErrorStrings := []string{
fmt.Sprintf(expectedErrorFormatSpecifier, macOSErr),
fmt.Sprintf(expectedErrorFormatSpecifier, standardErr),
}
// Allow either.
require.Contains(t, allowedErrorStrings, actualErrorStr)
}
}

View File

@ -1,19 +0,0 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
import (
"fmt"
"runtime"
)
func X509UntrustedCertError(commonName string) string {
if runtime.GOOS == "darwin" {
// Golang use's macos' x509 verification APIs on darwin.
// This output slightly different error messages than golang's
// own x509 verification.
return fmt.Sprintf(`x509: “%s” certificate is not trusted`, commonName)
}
return `x509: certificate signed by unknown authority`
}

View File

@ -1,4 +1,4 @@
// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package upstreamldap package upstreamldap
@ -179,7 +179,7 @@ func TestEndUserAuthentication(t *testing.T) {
searchMocks func(conn *mockldapconn.MockConn) searchMocks func(conn *mockldapconn.MockConn)
bindEndUserMocks func(conn *mockldapconn.MockConn) bindEndUserMocks func(conn *mockldapconn.MockConn)
dialError error dialError error
wantError string wantError testutil.RequireErrorStringFunc
wantToSkipDial bool wantToSkipDial bool
wantAuthResponse *authenticators.Response wantAuthResponse *authenticators.Response
wantUnauthenticated bool wantUnauthenticated bool
@ -711,7 +711,7 @@ func TestEndUserAuthentication(t *testing.T) {
Return(exampleGroupSearchResult, nil).Times(1) Return(exampleGroupSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: "found 0 values for attribute \"some-attribute-to-check-during-refresh\" while searching for user \"some-upstream-username\", but expected 1 result", wantError: testutil.WantExactErrorString("found 0 values for attribute \"some-attribute-to-check-during-refresh\" while searching for user \"some-upstream-username\", but expected 1 result"),
}, },
{ {
name: "when dial fails", name: "when dial fails",
@ -719,7 +719,7 @@ func TestEndUserAuthentication(t *testing.T) {
password: testUpstreamPassword, password: testUpstreamPassword,
providerConfig: providerConfig(nil), providerConfig: providerConfig(nil),
dialError: errors.New("some dial error"), dialError: errors.New("some dial error"),
wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), wantError: testutil.WantSprintfErrorString(`error dialing host "%s": some dial error`, testHost),
}, },
{ {
name: "when the UsernameAttribute is dn and there is not a user search filter provided", name: "when the UsernameAttribute is dn and there is not a user search filter provided",
@ -730,7 +730,7 @@ func TestEndUserAuthentication(t *testing.T) {
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
}), }),
wantToSkipDial: true, wantToSkipDial: true,
wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, wantError: testutil.WantExactErrorString(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`),
}, },
{ {
name: "when binding as the bind user returns an error", name: "when binding as the bind user returns an error",
@ -741,7 +741,7 @@ func TestEndUserAuthentication(t *testing.T) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`error binding as "%s" before user search: some bind error`, testBindUsername), wantError: testutil.WantSprintfErrorString(`error binding as "%s" before user search: some bind error`, testBindUsername),
}, },
{ {
name: "when searching for the user returns an error", name: "when searching for the user returns an error",
@ -753,7 +753,7 @@ func TestEndUserAuthentication(t *testing.T) {
conn.EXPECT().Search(expectedUserSearch(nil)).Return(nil, errors.New("some user search error")).Times(1) conn.EXPECT().Search(expectedUserSearch(nil)).Return(nil, errors.New("some user search error")).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: `error searching for user: some user search error`, wantError: testutil.WantExactErrorString(`error searching for user: some user search error`),
}, },
{ {
name: "when searching for the user's groups returns an error", name: "when searching for the user's groups returns an error",
@ -767,7 +767,7 @@ func TestEndUserAuthentication(t *testing.T) {
Return(nil, errors.New("some group search error")).Times(1) Return(nil, errors.New("some group search error")).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`error searching for group memberships for user with DN "%s": some group search error`, testUserSearchResultDNValue), wantError: testutil.WantSprintfErrorString(`error searching for group memberships for user with DN "%s": some group search error`, testUserSearchResultDNValue),
}, },
{ {
name: "when searching for the user returns no results", name: "when searching for the user returns no results",
@ -798,7 +798,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`searching for user "%s" resulted in 2 search results, but expected 1 result`, testUpstreamUsername), wantError: testutil.WantSprintfErrorString(`searching for user "%s" resulted in 2 search results, but expected 1 result`, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user without a DN", name: "when searching for the user returns a user without a DN",
@ -814,7 +814,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`searching for user "%s" resulted in search result without DN`, testUpstreamUsername), wantError: testutil.WantSprintfErrorString(`searching for user "%s" resulted in search result without DN`, testUpstreamUsername),
}, },
{ {
name: "when searching for the user's groups returns a group without a DN", name: "when searching for the user's groups returns a group without a DN",
@ -845,7 +845,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`searching for group memberships for user with DN "%s" resulted in search result without DN`, `searching for group memberships for user with DN "%s" resulted in search result without DN`,
testUserSearchResultDNValue), testUserSearchResultDNValue),
}, },
@ -868,7 +868,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, `found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`,
testUserSearchUsernameAttribute, testUpstreamUsername), testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
@ -901,7 +901,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`error searching for group memberships for user with DN "%s": found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, `error searching for group memberships for user with DN "%s": found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`,
testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue), testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue),
}, },
@ -928,7 +928,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, `found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`,
testUserSearchUsernameAttribute, testUpstreamUsername), testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
@ -964,7 +964,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`error searching for group memberships for user with DN "%s": found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, `error searching for group memberships for user with DN "%s": found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`,
testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue), testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue),
}, },
@ -988,7 +988,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, `found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`,
testUserSearchUsernameAttribute, testUpstreamUsername), testUserSearchUsernameAttribute, testUpstreamUsername),
}, },
@ -1021,7 +1021,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf( wantError: testutil.WantSprintfErrorString(
`error searching for group memberships for user with DN "%s": found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, `error searching for group memberships for user with DN "%s": found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`,
testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue), testUserSearchResultDNValue, testGroupSearchGroupNameAttribute, testUserSearchResultDNValue),
}, },
@ -1044,7 +1044,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: testutil.WantSprintfErrorString(`found 0 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with too many values for the expected UID attribute", name: "when searching for the user returns a user with too many values for the expected UID attribute",
@ -1069,7 +1069,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: testutil.WantSprintfErrorString(`found 2 values for attribute "%s" while searching for user "%s", but expected 1 result`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when searching for the user returns a user with an empty value for the expected UID attribute", name: "when searching for the user returns a user with an empty value for the expected UID attribute",
@ -1091,7 +1091,7 @@ func TestEndUserAuthentication(t *testing.T) {
}, nil).Times(1) }, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername), wantError: testutil.WantSprintfErrorString(`found empty value for attribute "%s" while searching for user "%s", but expected value to be non-empty`, testUserSearchUIDAttribute, testUpstreamUsername),
}, },
{ {
name: "when the group search has an override func that errors", name: "when the group search has an override func that errors",
@ -1109,7 +1109,7 @@ func TestEndUserAuthentication(t *testing.T) {
Return(exampleGroupSearchResult, nil).Times(1) Return(exampleGroupSearchResult, nil).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf("error finding groups for user %s: some error", testUserSearchResultDNValue), wantError: testutil.WantSprintfErrorString("error finding groups for user %s: some error", testUserSearchResultDNValue),
}, },
{ {
name: "when binding as the found user returns an error", name: "when binding as the found user returns an error",
@ -1127,7 +1127,7 @@ func TestEndUserAuthentication(t *testing.T) {
conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Bind(testUserSearchResultDNValue, testUpstreamPassword).Return(errors.New("some bind error")).Times(1)
}, },
skipDryRunAuthenticateUser: true, skipDryRunAuthenticateUser: true,
wantError: fmt.Sprintf(`error binding for user "%s" using provided password against DN "%s": some bind error`, testUpstreamUsername, testUserSearchResultDNValue), wantError: testutil.WantSprintfErrorString(`error binding for user "%s" using provided password against DN "%s": some bind error`, testUpstreamUsername, testUserSearchResultDNValue),
}, },
{ {
name: "when binding as the found user returns a specific invalid credentials error", name: "when binding as the found user returns a specific invalid credentials error",
@ -1194,8 +1194,8 @@ func TestEndUserAuthentication(t *testing.T) {
authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, tt.grantedScopes) authResponse, authenticated, err := ldapProvider.AuthenticateUser(context.Background(), tt.username, tt.password, tt.grantedScopes)
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
switch { switch {
case tt.wantError != "": case tt.wantError != nil:
require.EqualError(t, err, tt.wantError) testutil.RequireErrorStringFromErr(t, err, tt.wantError)
require.False(t, authenticated) require.False(t, authenticated)
require.Nil(t, authResponse) require.Nil(t, authResponse)
case tt.wantUnauthenticated: case tt.wantUnauthenticated:
@ -1226,8 +1226,8 @@ func TestEndUserAuthentication(t *testing.T) {
authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, tt.grantedScopes) authResponse, authenticated, err = ldapProvider.DryRunAuthenticateUser(context.Background(), tt.username, tt.grantedScopes)
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
switch { switch {
case tt.wantError != "": case tt.wantError != nil:
require.EqualError(t, err, tt.wantError) testutil.RequireErrorStringFromErr(t, err, tt.wantError)
require.False(t, authenticated) require.False(t, authenticated)
require.Nil(t, authResponse) require.Nil(t, authResponse)
case tt.wantUnauthenticated: case tt.wantUnauthenticated:
@ -1852,7 +1852,7 @@ func TestTestConnection(t *testing.T) {
providerConfig *ProviderConfig providerConfig *ProviderConfig
setupMocks func(conn *mockldapconn.MockConn) setupMocks func(conn *mockldapconn.MockConn)
dialError error dialError error
wantError string wantError testutil.RequireErrorStringFunc
wantToSkipDial bool wantToSkipDial bool
}{ }{
{ {
@ -1867,7 +1867,7 @@ func TestTestConnection(t *testing.T) {
name: "when dial fails", name: "when dial fails",
providerConfig: providerConfig(nil), providerConfig: providerConfig(nil),
dialError: errors.New("some dial error"), dialError: errors.New("some dial error"),
wantError: fmt.Sprintf(`error dialing host "%s": some dial error`, testHost), wantError: testutil.WantSprintfErrorString(`error dialing host "%s": some dial error`, testHost),
}, },
{ {
name: "when binding as the bind user returns an error", name: "when binding as the bind user returns an error",
@ -1876,7 +1876,7 @@ func TestTestConnection(t *testing.T) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1)
conn.EXPECT().Close().Times(1) conn.EXPECT().Close().Times(1)
}, },
wantError: fmt.Sprintf(`error binding as "%s": some bind error`, testBindUsername), wantError: testutil.WantSprintfErrorString(`error binding as "%s": some bind error`, testBindUsername),
}, },
{ {
name: "when the config is invalid", name: "when the config is invalid",
@ -1886,7 +1886,7 @@ func TestTestConnection(t *testing.T) {
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
}), }),
wantToSkipDial: true, wantToSkipDial: true,
wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, wantError: testutil.WantExactErrorString(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`),
}, },
} }
@ -1917,8 +1917,8 @@ func TestTestConnection(t *testing.T) {
require.Equal(t, !tt.wantToSkipDial, dialWasAttempted) require.Equal(t, !tt.wantToSkipDial, dialWasAttempted)
switch { switch {
case tt.wantError != "": case tt.wantError != nil:
require.EqualError(t, err, tt.wantError) testutil.RequireErrorStringFromErr(t, err, tt.wantError)
default: default:
require.NoError(t, err) require.NoError(t, err)
} }
@ -2010,7 +2010,7 @@ func TestRealTLSDialing(t *testing.T) {
connProto LDAPConnectionProtocol connProto LDAPConnectionProtocol
caBundle []byte caBundle []byte
context context.Context context context.Context
wantError string wantError testutil.RequireErrorStringFunc
}{ }{
{ {
name: "happy path", name: "happy path",
@ -2025,7 +2025,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: caForTestServerWithBadCertName.Bundle(), caBundle: caForTestServerWithBadCertName.Bundle(),
connProto: TLS, connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": x509: certificate is valid for 10.2.3.4, not 127.0.0.1`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": x509: certificate is valid for 10.2.3.4, not 127.0.0.1`),
}, },
{ {
name: "invalid CA bundle with TLS", name: "invalid CA bundle with TLS",
@ -2033,7 +2033,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: []byte("not a ca bundle"), caBundle: []byte("not a ca bundle"),
connProto: TLS, connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": could not parse CA bundle`),
}, },
{ {
name: "invalid CA bundle with StartTLS", name: "invalid CA bundle with StartTLS",
@ -2041,7 +2041,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: []byte("not a ca bundle"), caBundle: []byte("not a ca bundle"),
connProto: StartTLS, connProto: StartTLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": could not parse CA bundle`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": could not parse CA bundle`),
}, },
{ {
name: "invalid host with TLS", name: "invalid host with TLS",
@ -2049,7 +2049,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: testServerCABundle, caBundle: testServerCABundle,
connProto: TLS, connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`),
}, },
{ {
name: "invalid host with StartTLS", name: "invalid host with StartTLS",
@ -2057,7 +2057,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: testServerCABundle, caBundle: testServerCABundle,
connProto: StartTLS, connProto: StartTLS,
context: context.Background(), context: context.Background(),
wantError: `LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": host "this:is:not:a:valid:hostname" is not a valid hostname or IP address`),
}, },
{ {
name: "missing CA bundle when it is required because the host is not using a trusted CA", name: "missing CA bundle when it is required because the host is not using a trusted CA",
@ -2065,7 +2065,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: nil, caBundle: nil,
connProto: TLS, connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": %s`, testutil.X509UntrustedCertError("Acme Co")), wantError: testutil.WantX509UntrustedCertErrorString(`LDAP Result Code 200 "Network Error": %s`, "Acme Co"),
}, },
{ {
name: "cannot connect to host", name: "cannot connect to host",
@ -2074,7 +2074,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: testServerCABundle, caBundle: testServerCABundle,
connProto: TLS, connProto: TLS,
context: context.Background(), context: context.Background(),
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort), wantError: testutil.WantSprintfErrorString(`LDAP Result Code 200 "Network Error": dial tcp %s: connect: connection refused`, recentlyClaimedHostAndPort),
}, },
{ {
name: "pays attention to the passed context", name: "pays attention to the passed context",
@ -2082,7 +2082,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: testServerCABundle, caBundle: testServerCABundle,
connProto: TLS, connProto: TLS,
context: alreadyCancelledContext, context: alreadyCancelledContext,
wantError: fmt.Sprintf(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort), wantError: testutil.WantSprintfErrorString(`LDAP Result Code 200 "Network Error": dial tcp %s: operation was canceled`, testServerHostAndPort),
}, },
{ {
name: "unsupported connection protocol", name: "unsupported connection protocol",
@ -2090,7 +2090,7 @@ func TestRealTLSDialing(t *testing.T) {
caBundle: testServerCABundle, caBundle: testServerCABundle,
connProto: "bad usage of this type", connProto: "bad usage of this type",
context: alreadyCancelledContext, context: alreadyCancelledContext,
wantError: `LDAP Result Code 200 "Network Error": did not specify valid ConnectionProtocol`, wantError: testutil.WantExactErrorString(`LDAP Result Code 200 "Network Error": did not specify valid ConnectionProtocol`),
}, },
} }
for _, test := range tests { for _, test := range tests {
@ -2106,9 +2106,9 @@ func TestRealTLSDialing(t *testing.T) {
if conn != nil { if conn != nil {
defer conn.Close() defer conn.Close()
} }
if tt.wantError != "" { if tt.wantError != nil {
require.Nil(t, conn) require.Nil(t, conn)
require.EqualError(t, err, tt.wantError) testutil.RequireErrorStringFromErr(t, err, tt.wantError)
} else { } else {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, conn) require.NotNil(t, conn)

View File

@ -487,8 +487,7 @@ func TestProviderConfig(t *testing.T) {
unreachableServer bool unreachableServer bool
returnStatusCodes []int returnStatusCodes []int
returnErrBodies []string returnErrBodies []string
wantErr string wantErr testutil.RequireErrorStringFunc
wantErrRegexp string // use either wantErr or wantErrRegexp
wantRetryableErrType bool // additionally assert error type when wantErr is non-empty wantRetryableErrType bool // additionally assert error type when wantErr is non-empty
wantNumRequests int wantNumRequests int
wantTokenTypeHint string wantTokenTypeHint string
@ -542,7 +541,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`, wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -552,7 +551,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`invalid JSON body`}, returnErrBodies: []string{`invalid JSON body`},
wantErr: `error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`, wantErr: testutil.WantExactErrorString(`error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -562,7 +561,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{``}, returnErrBodies: []string{``},
wantErr: `error parsing response body "" on response with status code 400: unexpected end of JSON input`, wantErr: testutil.WantExactErrorString(`error parsing response body "" on response with status code 400: unexpected end of JSON input`),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -572,7 +571,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusForbidden}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: "server responded with status 403", wantErr: testutil.WantExactErrorString("server responded with status 403"),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -582,7 +581,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
wantErr: `server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`, wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -592,7 +591,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusForbidden}, returnStatusCodes: []int{http.StatusForbidden},
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: "server responded with status 403", wantErr: testutil.WantExactErrorString("server responded with status 403"),
wantRetryableErrType: false, wantRetryableErrType: false,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -602,7 +601,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusServiceUnavailable}, // 503 returnStatusCodes: []int{http.StatusServiceUnavailable}, // 503
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: "retryable revocation error: server responded with status 503", wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"),
wantRetryableErrType: true, wantRetryableErrType: true,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -612,7 +611,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.AccessTokenType, tokenType: provider.AccessTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusServiceUnavailable}, // 400, 503 returnStatusCodes: []int{http.StatusBadRequest, http.StatusServiceUnavailable}, // 400, 503
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: "retryable revocation error: server responded with status 503", wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"),
wantRetryableErrType: true, wantRetryableErrType: true,
wantNumRequests: 2, wantNumRequests: 2,
wantTokenTypeHint: "access_token", wantTokenTypeHint: "access_token",
@ -622,7 +621,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{http.StatusInternalServerError}, // 500 returnStatusCodes: []int{http.StatusInternalServerError}, // 500
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: "retryable revocation error: server responded with status 500", wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 500"),
wantRetryableErrType: true, wantRetryableErrType: true,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -632,7 +631,7 @@ func TestProviderConfig(t *testing.T) {
tokenType: provider.RefreshTokenType, tokenType: provider.RefreshTokenType,
returnStatusCodes: []int{599}, // not defined by an RFC, but sometimes considered Network Connect Timeout Error returnStatusCodes: []int{599}, // not defined by an RFC, but sometimes considered Network Connect Timeout Error
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: "retryable revocation error: server responded with status 599", wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 599"),
wantRetryableErrType: true, wantRetryableErrType: true,
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
@ -641,7 +640,7 @@ func TestProviderConfig(t *testing.T) {
name: "retryable error when the server cannot be reached", name: "retryable error when the server cannot be reached",
tokenType: provider.AccessTokenType, tokenType: provider.AccessTokenType,
unreachableServer: true, unreachableServer: true,
wantErrRegexp: "^retryable revocation error: Post .*: dial tcp .*: connect: connection refused$", wantErr: testutil.WantMatchingErrorString("^retryable revocation error: Post .*: dial tcp .*: connect: connection refused$"),
wantRetryableErrType: true, wantRetryableErrType: true,
wantNumRequests: 0, wantNumRequests: 0,
}, },
@ -709,13 +708,8 @@ func TestProviderConfig(t *testing.T) {
require.Equal(t, tt.wantNumRequests, numRequests, require.Equal(t, tt.wantNumRequests, numRequests,
"did not make expected number of requests to revocation endpoint") "did not make expected number of requests to revocation endpoint")
if tt.wantErr != "" || tt.wantErrRegexp != "" { //nolint:nestif if tt.wantErr != nil {
if tt.wantErr != "" { testutil.RequireErrorStringFromErr(t, err, tt.wantErr)
require.EqualError(t, err, tt.wantErr)
} else {
require.Error(t, err)
require.Regexp(t, tt.wantErrRegexp, err.Error())
}
if tt.wantRetryableErrType { if tt.wantRetryableErrType {
require.ErrorAs(t, err, &provider.RetryableRevocationError{}) require.ErrorAs(t, err, &provider.RetryableRevocationError{})

View File

@ -1,4 +1,4 @@
// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package integration package integration
@ -75,7 +75,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
password string password string
grantedScopes []string grantedScopes []string
provider *upstreamldap.Provider provider *upstreamldap.Provider
wantError string wantError testutil.RequireErrorStringFunc
wantAuthResponse *authenticators.Response wantAuthResponse *authenticators.Response
wantUnauthenticated bool wantUnauthenticated bool
}{ }{
@ -248,7 +248,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.UserSearch.UsernameAttribute = "dn" p.UserSearch.UsernameAttribute = "dn"
p.UserSearch.Filter = "" p.UserSearch.Filter = ""
})), })),
wantError: `must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`, wantError: testutil.WantExactErrorString(`must specify UserSearch Filter when UserSearch UsernameAttribute is "dn"`),
}, },
{ {
name: "group search disabled", name: "group search disabled",
@ -352,21 +352,21 @@ func TestLDAPSearch_Parallel(t *testing.T) {
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "invalid-dn" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "invalid-dn" })),
wantError: `error binding as "invalid-dn" before user search: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, wantError: testutil.WantExactErrorString(`error binding as "invalid-dn" before user search: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`),
}, },
{ {
name: "when the bind user username is wrong", name: "when the bind user username is wrong",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" })),
wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: testutil.WantExactErrorString(`error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `),
}, },
{ {
name: "when the bind user password is wrong", name: "when the bind user password is wrong",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.BindPassword = "wrong-password" })),
wantError: `error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: testutil.WantExactErrorString(`error binding as "cn=admin,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `),
}, },
{ {
name: "when the bind user username is wrong with StartTLS: example of an error after successful connection with StartTLS", name: "when the bind user username is wrong with StartTLS: example of an error after successful connection with StartTLS",
@ -377,7 +377,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.ConnectionProtocol = upstreamldap.StartTLS p.ConnectionProtocol = upstreamldap.StartTLS
p.BindUsername = "cn=wrong,dc=pinniped,dc=dev" p.BindUsername = "cn=wrong,dc=pinniped,dc=dev"
})), })),
wantError: `error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `, wantError: testutil.WantExactErrorString(`error binding as "cn=wrong,dc=pinniped,dc=dev" before user search: LDAP Result Code 49 "Invalid Credentials": `),
}, },
{ {
name: "when the end user password is wrong", name: "when the end user password is wrong",
@ -405,14 +405,14 @@ func TestLDAPSearch_Parallel(t *testing.T) {
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "*" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Filter = "*" })),
wantError: `error searching for user: LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`, wantError: testutil.WantExactErrorString(`error searching for user: LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`),
}, },
{ {
name: "when the group search filter does not compile", name: "when the group search filter does not compile",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Filter = "*" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Filter = "*" })),
wantError: `error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`, wantError: testutil.WantExactErrorString(`error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 201 "Filter Compile Error": ldap: error parsing filter`),
}, },
{ {
name: "when there are too many search results for the user", name: "when there are too many search results for the user",
@ -421,14 +421,14 @@ func TestLDAPSearch_Parallel(t *testing.T) {
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.Filter = "objectClass=*" // overly broad search filter p.UserSearch.Filter = "objectClass=*" // overly broad search filter
})), })),
wantError: `error searching for user: LDAP Result Code 4 "Size Limit Exceeded": `, wantError: testutil.WantExactErrorString(`error searching for user: LDAP Result Code 4 "Size Limit Exceeded": `),
}, },
{ {
name: "when the server is unreachable with TLS", name: "when the server is unreachable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedLocalhostPort })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + unusedLocalhostPort })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
}, },
{ {
name: "when the server is unreachable with StartTLS", name: "when the server is unreachable with StartTLS",
@ -438,14 +438,14 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.Host = "127.0.0.1:" + unusedLocalhostPort p.Host = "127.0.0.1:" + unusedLocalhostPort
p.ConnectionProtocol = upstreamldap.StartTLS p.ConnectionProtocol = upstreamldap.StartTLS
})), })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": dial tcp 127.0.0.1:%s: connect: connection refused`, unusedLocalhostPort, unusedLocalhostPort),
}, },
{ {
name: "when the server is not parsable with TLS", name: "when the server is not parsable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "too:many:ports" })),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, wantError: testutil.WantExactErrorString(`error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`),
}, },
{ {
name: "when the server is not parsable with StartTLS", name: "when the server is not parsable with StartTLS",
@ -456,14 +456,14 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.ConnectionProtocol = upstreamldap.StartTLS p.ConnectionProtocol = upstreamldap.StartTLS
p.Host = "too:many:ports" p.Host = "too:many:ports"
})), })),
wantError: `error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`, wantError: testutil.WantExactErrorString(`error dialing host "too:many:ports": LDAP Result Code 200 "Network Error": host "too:many:ports" is not a valid hostname or IP address`),
}, },
{ {
name: "when the CA bundle is not parsable with TLS", name: "when the CA bundle is not parsable with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = []byte("invalid-pem") })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapsLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapsLocalhostPort),
}, },
{ {
name: "when the CA bundle is not parsable with StartTLS", name: "when the CA bundle is not parsable with StartTLS",
@ -474,14 +474,14 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.ConnectionProtocol = upstreamldap.StartTLS p.ConnectionProtocol = upstreamldap.StartTLS
p.CABundle = []byte("invalid-pem") p.CABundle = []byte("invalid-pem")
})), })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": could not parse CA bundle`, ldapLocalhostPort),
}, },
{ {
name: "when the CA bundle does not cause the host to be trusted with TLS", name: "when the CA bundle does not cause the host to be trusted with TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.CABundle = nil })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": %s`, ldapsLocalhostPort, testutil.X509UntrustedCertError("Pinniped Test")), wantError: testutil.WantX509UntrustedCertErrorString(fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": %%s`, ldapsLocalhostPort), "Pinniped Test"),
}, },
{ {
name: "when the CA bundle does not cause the host to be trusted with StartTLS", name: "when the CA bundle does not cause the host to be trusted with StartTLS",
@ -492,35 +492,35 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.ConnectionProtocol = upstreamldap.StartTLS p.ConnectionProtocol = upstreamldap.StartTLS
p.CABundle = nil p.CABundle = nil
})), })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": TLS handshake failed (%s)`, ldapLocalhostPort, testutil.X509UntrustedCertError("Pinniped Test")), wantError: testutil.WantX509UntrustedCertErrorString(fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": TLS handshake failed (%%s)`, ldapLocalhostPort), "Pinniped Test"),
}, },
{ {
name: "when trying to use TLS to connect to a port which only supports StartTLS", name: "when trying to use TLS to connect to a port which only supports StartTLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + ldapLocalhostPort })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.Host = "127.0.0.1:" + ldapLocalhostPort })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": EOF`, ldapLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": LDAP Result Code 200 "Network Error": EOF`, ldapLocalhostPort),
}, },
{ {
name: "when trying to use StartTLS to connect to a port which only supports TLS", name: "when trying to use StartTLS to connect to a port which only supports TLS",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.ConnectionProtocol = upstreamldap.StartTLS })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.ConnectionProtocol = upstreamldap.StartTLS })),
wantError: fmt.Sprintf(`error dialing host "127.0.0.1:%s": unable to read LDAP response packet: unexpected EOF`, ldapsLocalhostPort), wantError: testutil.WantSprintfErrorString(`error dialing host "127.0.0.1:%s": unable to read LDAP response packet: unexpected EOF`, ldapsLocalhostPort),
}, },
{ {
name: "when the UsernameAttribute attribute has multiple values in the entry", name: "when the UsernameAttribute attribute has multiple values in the entry",
username: "wally.ldap@example.com", username: "wally.ldap@example.com",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "mail" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "mail" })),
wantError: `found 2 values for attribute "mail" while searching for user "wally.ldap@example.com", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 2 values for attribute "mail" while searching for user "wally.ldap@example.com", but expected 1 result`),
}, },
{ {
name: "when the UIDAttribute attribute has multiple values in the entry", name: "when the UIDAttribute attribute has multiple values in the entry",
username: "wally", username: "wally",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "mail" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "mail" })),
wantError: `found 2 values for attribute "mail" while searching for user "wally", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 2 values for attribute "mail" while searching for user "wally", but expected 1 result`),
}, },
{ {
name: "when the UsernameAttribute attribute is not found in the entry", name: "when the UsernameAttribute attribute is not found in the entry",
@ -530,35 +530,35 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.UserSearch.Filter = "cn={}" p.UserSearch.Filter = "cn={}"
p.UserSearch.UsernameAttribute = "attr-does-not-exist" p.UserSearch.UsernameAttribute = "attr-does-not-exist"
})), })),
wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`),
}, },
{ {
name: "when the UIDAttribute attribute is not found in the entry", name: "when the UIDAttribute attribute is not found in the entry",
username: "wally", username: "wally",
password: "unused-because-error-is-before-bind", password: "unused-because-error-is-before-bind",
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "attr-does-not-exist" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "attr-does-not-exist" })),
wantError: `found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "attr-does-not-exist" while searching for user "wally", but expected 1 result`),
}, },
{ {
name: "when the UsernameAttribute has the wrong case", name: "when the UsernameAttribute has the wrong case",
username: "Seal", username: "Seal",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "SN" })), // this is case-sensitive provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UsernameAttribute = "SN" })), // this is case-sensitive
wantError: `found 0 values for attribute "SN" while searching for user "Seal", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "SN" while searching for user "Seal", but expected 1 result`),
}, },
{ {
name: "when the UIDAttribute has the wrong case", name: "when the UIDAttribute has the wrong case",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "SN" })), // this is case-sensitive provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.UIDAttribute = "SN" })), // this is case-sensitive
wantError: `found 0 values for attribute "SN" while searching for user "pinny", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "SN" while searching for user "pinny", but expected 1 result`),
}, },
{ {
name: "when the GroupNameAttribute has the wrong case", name: "when the GroupNameAttribute has the wrong case",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.GroupNameAttribute = "CN" })), // this is case-sensitive provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.GroupNameAttribute = "CN" })), // this is case-sensitive
wantError: `error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": found 0 values for attribute "CN" while searching for user "cn=pinny,ou=users,dc=pinniped,dc=dev", but expected 1 result`, wantError: testutil.WantExactErrorString(`error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": found 0 values for attribute "CN" while searching for user "cn=pinny,ou=users,dc=pinniped,dc=dev", but expected 1 result`),
}, },
{ {
name: "when the UsernameAttribute is DN and has the wrong case", name: "when the UsernameAttribute is DN and has the wrong case",
@ -568,7 +568,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
p.UserSearch.UsernameAttribute = "DN" // dn must be lower-case p.UserSearch.UsernameAttribute = "DN" // dn must be lower-case
p.UserSearch.Filter = "cn={}" p.UserSearch.Filter = "cn={}"
})), })),
wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`),
}, },
{ {
name: "when the UIDAttribute is DN and has the wrong case", name: "when the UIDAttribute is DN and has the wrong case",
@ -577,7 +577,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.UserSearch.UIDAttribute = "DN" // dn must be lower-case p.UserSearch.UIDAttribute = "DN" // dn must be lower-case
})), })),
wantError: `found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`, wantError: testutil.WantExactErrorString(`found 0 values for attribute "DN" while searching for user "pinny", but expected 1 result`),
}, },
{ {
name: "when the GroupNameAttribute is DN and has the wrong case", name: "when the GroupNameAttribute is DN and has the wrong case",
@ -586,35 +586,35 @@ func TestLDAPSearch_Parallel(t *testing.T) {
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) {
p.GroupSearch.GroupNameAttribute = "DN" // dn must be lower-case p.GroupSearch.GroupNameAttribute = "DN" // dn must be lower-case
})), })),
wantError: `error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": found 0 values for attribute "DN" while searching for user "cn=pinny,ou=users,dc=pinniped,dc=dev", but expected 1 result`, wantError: testutil.WantExactErrorString(`error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": found 0 values for attribute "DN" while searching for user "cn=pinny,ou=users,dc=pinniped,dc=dev", but expected 1 result`),
}, },
{ {
name: "when the user search base is invalid", name: "when the user search base is invalid",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "invalid-base" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "invalid-base" })),
wantError: `error searching for user: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, wantError: testutil.WantExactErrorString(`error searching for user: LDAP Result Code 34 "Invalid DN Syntax": invalid DN`),
}, },
{ {
name: "when the group search base is invalid", name: "when the group search base is invalid",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Base = "invalid-base" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Base = "invalid-base" })),
wantError: `error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 34 "Invalid DN Syntax": invalid DN`, wantError: testutil.WantExactErrorString(`error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 34 "Invalid DN Syntax": invalid DN`),
}, },
{ {
name: "when the user search base does not exist", name: "when the user search base does not exist",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.UserSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })),
wantError: `error searching for user: LDAP Result Code 32 "No Such Object": `, wantError: testutil.WantExactErrorString(`error searching for user: LDAP Result Code 32 "No Such Object": `),
}, },
{ {
name: "when the group search base does not exist", name: "when the group search base does not exist",
username: "pinny", username: "pinny",
password: pinnyPassword, password: pinnyPassword,
provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })), provider: upstreamldap.New(*providerConfig(func(p *upstreamldap.ProviderConfig) { p.GroupSearch.Base = "ou=does-not-exist,dc=pinniped,dc=dev" })),
wantError: `error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 32 "No Such Object": `, wantError: testutil.WantExactErrorString(`error searching for group memberships for user with DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 32 "No Such Object": `),
}, },
{ {
name: "when the user search base causes no search results", name: "when the user search base causes no search results",
@ -635,7 +635,7 @@ func TestLDAPSearch_Parallel(t *testing.T) {
username: "pinny", username: "pinny",
password: "", password: "",
provider: upstreamldap.New(*providerConfig(nil)), provider: upstreamldap.New(*providerConfig(nil)),
wantError: `error binding for user "pinny" using provided password against DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 206 "Empty password not allowed by the client": ldap: empty password not allowed by the client`, wantError: testutil.WantExactErrorString(`error binding for user "pinny" using provided password against DN "cn=pinny,ou=users,dc=pinniped,dc=dev": LDAP Result Code 206 "Empty password not allowed by the client": ldap: empty password not allowed by the client`),
}, },
{ {
name: "when the user has no password in their entry", name: "when the user has no password in their entry",
@ -655,8 +655,8 @@ func TestLDAPSearch_Parallel(t *testing.T) {
authResponse, authenticated, err := tt.provider.AuthenticateUser(ctx, tt.username, tt.password, tt.grantedScopes) authResponse, authenticated, err := tt.provider.AuthenticateUser(ctx, tt.username, tt.password, tt.grantedScopes)
switch { switch {
case tt.wantError != "": case tt.wantError != nil:
require.EqualError(t, err, tt.wantError) testutil.RequireErrorStringFromErr(t, err, tt.wantError)
require.False(t, authenticated, "expected the user not to be authenticated, but they were") require.False(t, authenticated, "expected the user not to be authenticated, but they were")
require.Nil(t, authResponse) require.Nil(t, authResponse)
case tt.wantUnauthenticated: case tt.wantUnauthenticated: