diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index e6917954..8736c84b 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -5,6 +5,7 @@ package auth import ( "context" + "errors" "fmt" "html" "net/http" @@ -36,6 +37,16 @@ import ( func TestAuthorizationEndpoint(t *testing.T) { const ( + passwordGrantUpstreamName = "some-password-granting-oidc-idp" + + oidcUpstreamIssuer = "https://my-upstream-issuer.com" + oidcUpstreamSubject = "abc123-some guid" // has a space character which should get escaped in URL + oidcUpstreamSubjectQueryEscaped = "abc123-some+guid" + oidcUpstreamUsername = "test-oidc-pinniped-username" + oidcUpstreamPassword = "test-oidc-pinniped-password" //nolint: gosec + oidcUpstreamUsernameClaim = "the-user-claim" + oidcUpstreamGroupsClaim = "the-groups-claim" + downstreamIssuer = "https://my-downstream-issuer.com/some-path" downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" @@ -51,6 +62,8 @@ func TestAuthorizationEndpoint(t *testing.T) { require.Len(t, happyState, 8, "we expect fosite to allow 8 byte state params, so we want to test that boundary case") var ( + oidcUpstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} + fositeInvalidClientErrorBody = here.Doc(` { "error": "invalid_client", @@ -145,11 +158,31 @@ func TestAuthorizationEndpoint(t *testing.T) { upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) - upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{ - Name: "some-oidc-idp", - ClientID: "some-client-id", - AuthorizationURL: *upstreamAuthURL, - Scopes: []string{"scope1", "scope2"}, // the scopes to request when starting the upstream authorization flow + upstreamOIDCIdentityProvider := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName("some-oidc-idp"). + WithClientID("some-client-id"). + WithAuthorizationURL(*upstreamAuthURL). + WithScopes([]string{"scope1", "scope2"}). // the scopes to request when starting the upstream authorization flow + WithAllowPasswordGrant(false). + WithPasswordGrantError(errors.New("should not have used password grant on this instance")). + Build() + + passwordGrantUpstreamOIDCIdentityProviderBuilder := func() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { + return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(passwordGrantUpstreamName). + WithClientID("some-client-id"). + WithAuthorizationURL(*upstreamAuthURL). + WithScopes([]string{"scope1", "scope2"}). // the scopes to request when starting the upstream authorization flow + WithAllowPasswordGrant(false). + WithUsernameClaim(oidcUpstreamUsernameClaim). + WithGroupsClaim(oidcUpstreamGroupsClaim). + WithIDTokenClaim("iss", oidcUpstreamIssuer). + WithIDTokenClaim("sub", oidcUpstreamSubject). + WithIDTokenClaim(oidcUpstreamUsernameClaim, oidcUpstreamUsername). + WithIDTokenClaim(oidcUpstreamGroupsClaim, oidcUpstreamGroupMembership). + WithIDTokenClaim("other-claim", "should be ignored"). + WithAllowPasswordGrant(true). + WithUpstreamAuthcodeExchangeError(errors.New("should not have tried to exchange upstream authcode on this instance")) } happyLDAPUsername := "some-ldap-user" @@ -322,7 +355,7 @@ func TestAuthorizationEndpoint(t *testing.T) { type testCase struct { name string - idpLister provider.DynamicUpstreamIDPProvider + idps *oidctestutil.UpstreamIDPListerBuilder generateCSRF func() (csrftoken.CSRFToken, error) generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) @@ -345,7 +378,8 @@ func TestAuthorizationEndpoint(t *testing.T) { wantLocationHeader string wantUpstreamStateParamInLocationHeader bool - // For when the request was authenticated by an upstream LDAP provider and an authcode is being returned. + // Assertions for when an authcode should be returned, i.e. the request was authenticated by an + // upstream LDAP provider or an upstream OIDC password grant flow. wantRedirectLocationRegexp string wantDownstreamRedirectURI string wantDownstreamGrantedScopes []string @@ -357,11 +391,12 @@ func TestAuthorizationEndpoint(t *testing.T) { wantDownstreamPKCEChallengeMethod string wantDownstreamNonce string wantUnnecessaryStoredRecords int + wantPasswordGrantCall *expectedPasswordGrant } tests := []testCase{ { - name: "OIDC upstream happy path using GET without a CSRF cookie", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + name: "OIDC upstream browser flow happy path using GET without a CSRF cookie", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -376,9 +411,35 @@ func TestAuthorizationEndpoint(t *testing.T) { wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, + { + name: "OIDC upstream password grant happy path using GET", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(passwordGrantUpstreamOIDCIdentityProviderBuilder().Build()), + method: http.MethodGet, + path: happyGetRequestPath, + customUsernameHeader: pointer.StringPtr(oidcUpstreamUsername), + customPasswordHeader: pointer.StringPtr(oidcUpstreamPassword), + wantPasswordGrantCall: &expectedPasswordGrant{ + performedByUpstreamName: passwordGrantUpstreamName, + args: &oidctestutil.PasswordCredentialsGrantAndValidateTokensArgs{ + Username: oidcUpstreamUsername, + Password: oidcUpstreamPassword, + }}, + wantStatus: http.StatusFound, + wantContentType: htmlContentType, + wantRedirectLocationRegexp: happyAuthcodeDownstreamRedirectLocationRegexp, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, + wantDownstreamRequestedScopes: happyDownstreamScopesRequested, + wantDownstreamRedirectURI: downstreamRedirectURI, + wantDownstreamGrantedScopes: happyDownstreamScopesGranted, + wantDownstreamNonce: downstreamNonce, + wantDownstreamPKCEChallenge: downstreamPKCEChallenge, + wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, + }, { name: "LDAP upstream happy path using GET", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -386,7 +447,6 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: htmlContentType, wantRedirectLocationRegexp: happyAuthcodeDownstreamRedirectLocationRegexp, - wantBodyStringWithLocationInHref: false, wantDownstreamIDTokenSubject: upstreamLDAPURL + "&sub=" + happyLDAPUID, wantDownstreamIDTokenUsername: happyLDAPUsernameFromAuthenticator, wantDownstreamIDTokenGroups: happyLDAPGroups, @@ -399,7 +459,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream happy path using GET with a CSRF cookie", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -416,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream happy path using POST", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -435,7 +495,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "LDAP upstream happy path using POST", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodPost, path: "/some/path", contentType: "application/x-www-form-urlencoded", @@ -445,7 +505,6 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: htmlContentType, wantRedirectLocationRegexp: happyAuthcodeDownstreamRedirectLocationRegexp, - wantBodyStringWithLocationInHref: false, wantDownstreamIDTokenSubject: upstreamLDAPURL + "&sub=" + happyLDAPUID, wantDownstreamIDTokenUsername: happyLDAPUsernameFromAuthenticator, wantDownstreamIDTokenGroups: happyLDAPGroups, @@ -458,7 +517,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream happy path with prompt param login passed through to redirect uri", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -477,7 +536,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream with error while decoding CSRF cookie just generates a new cookie and succeeds as usual", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -496,7 +555,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream happy path when downstream redirect uri matches what is configured for client except for the port number", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -516,9 +575,9 @@ func TestAuthorizationEndpoint(t *testing.T) { wantBodyStringWithLocationInHref: true, }, { - name: "LDAP upstream happy path when downstream redirect uri matches what is configured for client except for the port number", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), - method: http.MethodGet, + name: "LDAP upstream happy path when downstream redirect uri matches what is configured for client except for the port number", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), + method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client }), @@ -527,7 +586,6 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: htmlContentType, wantRedirectLocationRegexp: downstreamRedirectURIWithDifferentPort + `\?code=([^&]+)&scope=openid&state=` + happyState, - wantBodyStringWithLocationInHref: false, wantDownstreamIDTokenSubject: upstreamLDAPURL + "&sub=" + happyLDAPUID, wantDownstreamIDTokenUsername: happyLDAPUsernameFromAuthenticator, wantDownstreamIDTokenGroups: happyLDAPGroups, @@ -540,7 +598,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "OIDC upstream happy path when downstream requested scopes include offline_access", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -559,7 +617,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error during upstream LDAP authentication", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&erroringUpstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&erroringUpstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -570,7 +628,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "wrong upstream password for LDAP authentication", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -582,7 +640,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "wrong upstream username for LDAP authentication", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: pointer.StringPtr("wrong-username"), @@ -594,7 +652,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing upstream username on request for LDAP authentication", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: nil, // do not send header @@ -606,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing upstream password on request for LDAP authentication", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: happyGetRequestPath, customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -618,7 +676,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream redirect uri does not match what is configured for client when using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -633,9 +691,9 @@ func TestAuthorizationEndpoint(t *testing.T) { wantBodyJSON: fositeInvalidRedirectURIErrorBody, }, { - name: "downstream redirect uri does not match what is configured for client when using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), - method: http.MethodGet, + name: "downstream redirect uri does not match what is configured for client when using LDAP upstream", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), + method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{ "redirect_uri": "http://127.0.0.1/does-not-match-what-is-configured-for-pinniped-cli-client", }), @@ -647,7 +705,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream client does not exist when using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -661,7 +719,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream client does not exist when using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"client_id": "invalid-client"}), wantStatus: http.StatusUnauthorized, @@ -670,7 +728,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "response type is unsupported when using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -685,7 +743,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "response type is unsupported when using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": "unsupported"}), wantStatus: http.StatusFound, @@ -695,7 +753,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream scopes do not match what is configured for client using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -710,7 +768,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream scopes do not match what is configured for client using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"scope": "openid tuna"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -722,7 +780,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing response type in request using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -737,7 +795,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing response type in request using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"response_type": ""}), wantStatus: http.StatusFound, @@ -747,7 +805,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing client id in request using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -761,7 +819,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing client id in request using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"client_id": ""}), wantStatus: http.StatusUnauthorized, @@ -770,7 +828,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge in request using OIDC upstream", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -785,7 +843,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge in request using LDAP upstream", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge": ""}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -798,7 +856,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "invalid value for PKCE code_challenge_method in request using OIDC upstream", // https://tools.ietf.org/html/rfc7636#section-4.3 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -813,7 +871,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "invalid value for PKCE code_challenge_method in request using LDAP upstream", // https://tools.ietf.org/html/rfc7636#section-4.3 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "this-is-not-a-valid-pkce-alg"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -826,7 +884,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "when PKCE code_challenge_method in request is `plain` using OIDC upstream", // https://tools.ietf.org/html/rfc7636#section-4.3 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -841,7 +899,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "when PKCE code_challenge_method in request is `plain` using LDAP upstream", // https://tools.ietf.org/html/rfc7636#section-4.3 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": "plain"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -854,7 +912,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge_method in request using OIDC upstream", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -869,7 +927,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "missing PKCE code_challenge_method in request using LDAP upstream", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"code_challenge_method": ""}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -884,7 +942,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // through that part of the fosite library when using an OIDC upstream. name: "prompt param is not allowed to have none and another legal value at the same time using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -901,7 +959,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // through that part of the fosite library when using an LDAP upstream. name: "prompt param is not allowed to have none and another legal value at the same time using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -914,7 +972,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "happy path: downstream OIDC validations are skipped when the openid scope was not requested using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -933,9 +991,9 @@ func TestAuthorizationEndpoint(t *testing.T) { wantBodyStringWithLocationInHref: true, }, { - name: "happy path: downstream OIDC validations are skipped when the openid scope was not requested using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), - method: http.MethodGet, + name: "happy path: downstream OIDC validations are skipped when the openid scope was not requested using LDAP upstream", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), + method: http.MethodGet, // The following prompt value is illegal when openid is requested, but note that openid is not requested. path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -943,7 +1001,6 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: htmlContentType, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyState, // no scopes granted - wantBodyStringWithLocationInHref: false, wantDownstreamIDTokenSubject: upstreamLDAPURL + "&sub=" + happyLDAPUID, wantDownstreamIDTokenUsername: happyLDAPUsernameFromAuthenticator, wantDownstreamIDTokenGroups: happyLDAPGroups, @@ -956,7 +1013,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream state does not have enough entropy using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -971,7 +1028,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "downstream state does not have enough entropy using LDAP upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider), method: http.MethodGet, path: modifiedHappyGetRequestPath(map[string]string{"state": "short"}), customUsernameHeader: pointer.StringPtr(happyLDAPUsername), @@ -983,7 +1040,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding upstream state param using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -997,7 +1054,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while encoding CSRF cookie value for new cookie using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -1011,7 +1068,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating CSRF token using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: sadCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -1025,7 +1082,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating nonce using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: sadNonceGenerator, @@ -1039,7 +1096,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "error while generating PKCE using OIDC upstream", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: sadPKCEGenerator, generateNonce: happyNonceGenerator, @@ -1053,7 +1110,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "no upstream providers are configured", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC().Build(), // empty + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(), // empty method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -1062,7 +1119,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "too many upstream providers are configured: multiple OIDC", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider).Build(), // more than one not allowed + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -1071,7 +1128,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "too many upstream providers are configured: multiple LDAP", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider, &upstreamLDAPIdentityProvider).Build(), // more than one not allowed + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&upstreamLDAPIdentityProvider, &upstreamLDAPIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -1080,7 +1137,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "too many upstream providers are configured: both OIDC and LDAP", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).WithLDAP(&upstreamLDAPIdentityProvider).Build(), // more than one not allowed + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider).WithLDAP(&upstreamLDAPIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -1089,7 +1146,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PUT is a bad method", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -1098,7 +1155,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "PATCH is a bad method", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -1107,7 +1164,7 @@ func TestAuthorizationEndpoint(t *testing.T) { }, { name: "DELETE is a bad method", - idpLister: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&upstreamOIDCIdentityProvider).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -1117,7 +1174,8 @@ func TestAuthorizationEndpoint(t *testing.T) { } runOneTestCase := func(t *testing.T, test testCase, subject http.Handler, kubeOauthStore *oidc.KubeStorage, kubeClient *fake.Clientset, secretsClient v1.SecretInterface) { - req := httptest.NewRequest(test.method, test.path, strings.NewReader(test.body)) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") + req := httptest.NewRequest(test.method, test.path, strings.NewReader(test.body)).WithContext(reqContext) req.Header.Set("Content-Type", test.contentType) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -1137,6 +1195,15 @@ func TestAuthorizationEndpoint(t *testing.T) { testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) testutil.RequireSecurityHeaders(t, rsp) + if test.wantPasswordGrantCall != nil { + test.wantPasswordGrantCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToPasswordCredentialsGrantAndValidateTokens(t, + test.wantPasswordGrantCall.performedByUpstreamName, test.wantPasswordGrantCall.args, + ) + } else { + test.idps.RequireExactlyZeroCallsToPasswordCredentialsGrantAndValidateTokens(t) + } + actualLocation := rsp.Header().Get("Location") switch { case test.wantLocationHeader != "": @@ -1212,7 +1279,7 @@ func TestAuthorizationEndpoint(t *testing.T) { oauthHelperWithRealStorage, kubeOauthStore := createOauthHelperWithRealStorage(secretsClient) subject := NewHandler( downstreamIssuer, - test.idpLister, + test.idps.Build(), oauthHelperWithNullStorage, oauthHelperWithRealStorage, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder, @@ -1223,14 +1290,16 @@ func TestAuthorizationEndpoint(t *testing.T) { t.Run("allows upstream provider configuration to change between requests", func(t *testing.T) { test := tests[0] - require.Equal(t, "OIDC upstream happy path using GET without a CSRF cookie", test.name) // re-use the happy path test case + // Double-check that we are re-using the happy path test case here as we intend. + require.Equal(t, "OIDC upstream browser flow happy path using GET without a CSRF cookie", test.name) kubeClient := fake.NewSimpleClientset() secretsClient := kubeClient.CoreV1().Secrets("some-namespace") oauthHelperWithRealStorage, kubeOauthStore := createOauthHelperWithRealStorage(secretsClient) + idpLister := test.idps.Build() subject := NewHandler( downstreamIssuer, - test.idpLister, + idpLister, oauthHelperWithNullStorage, oauthHelperWithRealStorage, test.generateCSRF, test.generatePKCE, test.generateNonce, test.stateEncoder, test.cookieEncoder, @@ -1238,23 +1307,25 @@ func TestAuthorizationEndpoint(t *testing.T) { runOneTestCase(t, test, subject, kubeOauthStore, kubeClient, secretsClient) - // Call the setter to change the upstream IDP settings. - newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{ - Name: "some-other-idp", - ClientID: "some-other-client-id", - AuthorizationURL: *upstreamAuthURL, - Scopes: []string{"other-scope1", "other-scope2"}, - } - test.idpLister.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)}) + // Call the idpLister's setter to change the upstream IDP settings. + newProviderSettings := oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName("some-other-new-idp-name"). + WithClientID("some-other-new-client-id"). + WithAuthorizationURL(*upstreamAuthURL). + WithScopes([]string{"some-other-new-scope1", "some-other-new-scope2"}). + Build() + idpLister.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(newProviderSettings)}) // Update the expectations of the test case to match the new upstream IDP settings. test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), map[string]string{ - "response_type": "code", - "access_type": "offline", - "scope": "other-scope1 other-scope2", - "client_id": "some-other-client-id", - "state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name), + "response_type": "code", + "access_type": "offline", + "scope": "some-other-new-scope1 some-other-new-scope2", // updated expectation + "client_id": "some-other-new-client-id", // updated expectation + "state": expectedUpstreamStateParam( + nil, "", "some-other-new-idp-name", + ), // updated expectation "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": downstreamPKCEChallengeMethod, @@ -1282,6 +1353,11 @@ func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) { return "", fmt.Errorf("some encoding error") } +type expectedPasswordGrant struct { + performedByUpstreamName string + args *oidctestutil.PasswordCredentialsGrantAndValidateTokensArgs +} + func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) { t.Helper() actualLocationURL, err := url.Parse(actualURL) diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 23912944..e2b9a31c 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -21,20 +21,19 @@ import ( "go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/pkg/oidcclient/nonce" - "go.pinniped.dev/pkg/oidcclient/oidctypes" oidcpkce "go.pinniped.dev/pkg/oidcclient/pkce" ) const ( happyUpstreamIDPName = "upstream-idp-name" - upstreamIssuer = "https://my-upstream-issuer.com" - upstreamSubject = "abc123-some guid" // has a space character which should get escaped in URL - queryEscapedUpstreamSubject = "abc123-some+guid" - upstreamUsername = "test-pinniped-username" + oidcUpstreamIssuer = "https://my-upstream-issuer.com" + oidcUpstreamSubject = "abc123-some guid" // has a space character which should get escaped in URL + oidcUpstreamSubjectQueryEscaped = "abc123-some+guid" + oidcUpstreamUsername = "test-pinniped-username" - upstreamUsernameClaim = "the-user-claim" - upstreamGroupsClaim = "the-groups-claim" + oidcUpstreamUsernameClaim = "the-user-claim" + oidcUpstreamGroupsClaim = "the-groups-claim" happyUpstreamAuthcode = "upstream-auth-code" happyUpstreamRedirectURI = "https://example.com/callback" @@ -56,7 +55,7 @@ const ( ) var ( - upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} + oidcUpstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"} happyDownstreamScopesRequested = []string{"openid"} happyDownstreamScopesGranted = []string{"openid"} @@ -113,7 +112,7 @@ func TestCallbackEndpoint(t *testing.T) { tests := []struct { name string - idp oidctestutil.TestUpstreamOIDCIdentityProvider + idps *oidctestutil.UpstreamIDPListerBuilder method string path string csrfCookie string @@ -132,11 +131,11 @@ func TestCallbackEndpoint(t *testing.T) { wantDownstreamPKCEChallenge string wantDownstreamPKCEChallengeMethod string - wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs + wantAuthcodeExchangeCall *expectedAuthcodeExchange }{ { name: "GET with good state and cookie and successful upstream token exchange with response_mode=form_post returns 200 with HTML+JS form", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState( happyUpstreamStateParam().WithAuthorizeRequestParams( @@ -150,204 +149,254 @@ func TestCallbackEndpoint(t *testing.T) { wantStatus: http.StatusOK, wantContentType: "text/html;charset=UTF-8", wantBodyFormResponseRegexp: `(.+)`, - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamUsername, - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamUsername, - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups", - idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), + name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamIDTokenGroups: []string{}, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "upstream IDP configures username claim as special claim `email` and `email_verified` upstream claim is missing", - idp: happyUpstream().WithUsernameClaim("email"). - WithIDTokenClaim("email", "joe@whitehouse.gov").Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUsernameClaim("email").WithIDTokenClaim("email", "joe@whitehouse.gov").Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamIDTokenUsername: "joe@whitehouse.gov", - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "upstream IDP configures username claim as special claim `email` and `email_verified` upstream claim is present with true value", - idp: happyUpstream().WithUsernameClaim("email"). - WithIDTokenClaim("email", "joe@whitehouse.gov"). - WithIDTokenClaim("email_verified", true).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUsernameClaim("email"). + WithIDTokenClaim("email", "joe@whitehouse.gov"). + WithIDTokenClaim("email_verified", true).Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamIDTokenUsername: "joe@whitehouse.gov", - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "upstream IDP configures username claim as anything other than special claim `email` and `email_verified` upstream claim is present with false value", - idp: happyUpstream().WithUsernameClaim("some-claim"). - WithIDTokenClaim("some-claim", "joe"). - WithIDTokenClaim("email", "joe@whitehouse.gov"). - WithIDTokenClaim("email_verified", false).Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUsernameClaim("some-claim"). + WithIDTokenClaim("some-claim", "joe"). + WithIDTokenClaim("email", "joe@whitehouse.gov"). + WithIDTokenClaim("email_verified", false).Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, // succeed despite `email_verified=false` because we're not using the email claim for anything wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamIDTokenUsername: "joe", - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "upstream IDP configures username claim as special claim `email` and `email_verified` upstream claim is present with illegal value", - idp: happyUpstream().WithUsernameClaim("email"). + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().WithUsernameClaim("email"). WithIDTokenClaim("email", "joe@whitehouse.gov"). WithIDTokenClaim("email_verified", "supposed to be boolean").Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "upstream IDP configures username claim as special claim `email` and `email_verified` upstream claim is present with false value", - idp: happyUpstream().WithUsernameClaim("email"). - WithIDTokenClaim("email", "joe@whitehouse.gov"). - WithIDTokenClaim("email_verified", false).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has false value\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUsernameClaim("email"). + WithIDTokenClaim("email", "joe@whitehouse.gov"). + WithIDTokenClaim("email_verified", false).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: email_verified claim in upstream ID token has false value\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for", - idp: happyUpstream().WithUsernameClaim("sub").Build(), + name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUsernameClaim("sub").Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamSubject, - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamSubject, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream IDP's configured groups claim in the ID token has a non-array value", - idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, "notAnArrayGroup1 notAnArrayGroup2").Build(), + name: "upstream IDP's configured groups claim in the ID token has a non-array value", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamGroupsClaim, "notAnArrayGroup1 notAnArrayGroup2").Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, wantDownstreamIDTokenGroups: []string{"notAnArrayGroup1 notAnArrayGroup2"}, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream IDP's configured groups claim in the ID token is a slice of interfaces", - idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, []interface{}{"group1", "group2"}).Build(), + name: "upstream IDP's configured groups claim in the ID token is a slice of interfaces", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamGroupsClaim, []interface{}{"group1", "group2"}).Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, wantDownstreamIDTokenGroups: []string{"group1", "group2"}, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, // Pre-upstream-exchange verification { name: "PUT method is invalid", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodPut, path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, @@ -356,6 +405,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "POST method is invalid", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodPost, path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, @@ -364,6 +414,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "PATCH method is invalid", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodPatch, path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, @@ -372,6 +423,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "DELETE method is invalid", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodDelete, path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, @@ -380,6 +432,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "code param was not included on request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyState).WithoutCode().String(), csrfCookie: happyCSRFCookie, @@ -389,6 +442,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state param was not included on request", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithoutState().String(), csrfCookie: happyCSRFCookie, @@ -398,7 +452,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState("this-will-not-decode").String(), csrfCookie: happyCSRFCookie, @@ -410,22 +464,26 @@ func TestCallbackEndpoint(t *testing.T) { // This shouldn't happen in practice because the authorize endpoint should have already run the same // validations, but we would like to test the error handling in this endpoint anyway. name: "state param contains authorization request params which fail validation", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState( happyUpstreamStateParam(). WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyDownstreamRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()). Build(t, happyStateCodec), ).String(), - csrfCookie: happyCSRFCookie, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, - wantStatus: http.StatusInternalServerError, - wantContentType: htmlContentType, - wantBody: "Internal Server Error: error while generating and saving authcode\n", + csrfCookie: happyCSRFCookie, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, + + wantStatus: http.StatusInternalServerError, + wantContentType: htmlContentType, + wantBody: "Internal Server Error: error while generating and saving authcode\n", }, { name: "state's internal version does not match what we want", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(), csrfCookie: happyCSRFCookie, @@ -435,7 +493,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params element is invalid", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyUpstreamStateParam(). WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z"). @@ -447,7 +505,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params are missing required value (e.g., client_id)", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState( happyUpstreamStateParam(). @@ -461,7 +519,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "state's downstream auth params does not contain openid scope", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath(). WithState( @@ -472,18 +530,21 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, - wantDownstreamIDTokenUsername: upstreamUsername, - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamRequestedScopes: []string{"profile", "email"}, - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "state's downstream auth params also included offline_access scope", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath(). WithState( @@ -494,19 +555,22 @@ func TestCallbackEndpoint(t *testing.T) { csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid\+offline_access&state=` + happyDownstreamState, - wantDownstreamIDTokenUsername: upstreamUsername, - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, wantDownstreamRequestedScopes: []string{"openid", "offline_access"}, wantDownstreamGrantedScopes: []string{"openid", "offline_access"}, - wantDownstreamIDTokenGroups: upstreamGroupMembership, + wantDownstreamIDTokenGroups: oidcUpstreamGroupMembership, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { name: "the OIDCIdentityProvider CRD has been deleted", - idp: otherUpstreamOIDCIdentityProvider, + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&otherUpstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, @@ -516,7 +580,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "the CSRF cookie does not exist on request", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, @@ -525,7 +589,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", @@ -535,7 +599,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie csrf value does not match state csrf value", - idp: happyUpstream().Build(), + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(happyUpstream().Build()), method: http.MethodGet, path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(), csrfCookie: happyCSRFCookie, @@ -546,111 +610,156 @@ func TestCallbackEndpoint(t *testing.T) { // Upstream exchange { - name: "upstream auth code exchange fails", - idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadGateway, - wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", - wantContentType: htmlContentType, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream auth code exchange fails", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithUpstreamAuthcodeExchangeError(errors.New("some error")).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + wantContentType: htmlContentType, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token does not contain requested username claim", - idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", - wantContentType: htmlContentType, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token does not contain requested username claim", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithoutIDTokenClaim(oidcUpstreamUsernameClaim).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: no username claim in upstream ID token\n", + wantContentType: htmlContentType, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token does not contain requested groups claim", - idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(), + name: "upstream ID token does not contain requested groups claim", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithoutIDTokenClaim(oidcUpstreamGroupsClaim).Build(), + ), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusFound, wantRedirectLocationRegexp: happyDownstreamRedirectLocationRegexp, wantBody: "", - wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + queryEscapedUpstreamSubject, - wantDownstreamIDTokenUsername: upstreamUsername, + wantDownstreamIDTokenSubject: oidcUpstreamIssuer + "?sub=" + oidcUpstreamSubjectQueryEscaped, + wantDownstreamIDTokenUsername: oidcUpstreamUsername, wantDownstreamRequestedScopes: happyDownstreamScopesRequested, wantDownstreamGrantedScopes: happyDownstreamScopesGranted, wantDownstreamIDTokenGroups: []string{}, wantDownstreamNonce: downstreamNonce, wantDownstreamPKCEChallenge: downstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod: downstreamPKCEChallengeMethod, - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token contains username claim with weird format", - idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token contains username claim with weird format", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamUsernameClaim, 42).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token does not contain iss claim when using default username claim config", - idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token does not contain iss claim when using default username claim config", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token has an non-string iss claim when using default username claim config", - idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token has an non-string iss claim when using default username claim config", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token contains groups claim with weird format", - idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token contains groups claim with weird format", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamGroupsClaim, 42).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token contains groups claim where one element is invalid", - idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, []interface{}{"foo", 7}).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token contains groups claim where one element is invalid", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamGroupsClaim, []interface{}{"foo", 7}).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, { - name: "upstream ID token contains groups claim with invalid null type", - idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, nil).Build(), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantContentType: htmlContentType, - wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", - wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + name: "upstream ID token contains groups claim with invalid null type", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + happyUpstream().WithIDTokenClaim(oidcUpstreamGroupsClaim, nil).Build(), + ), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantContentType: htmlContentType, + wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n", + wantAuthcodeExchangeCall: &expectedAuthcodeExchange{ + performedByUpstreamName: happyUpstreamIDPName, + args: happyExchangeAndValidateTokensArgs, + }, }, } for _, test := range tests { @@ -669,9 +778,9 @@ func TestCallbackEndpoint(t *testing.T) { jwksProviderIsUnused := jwks.NewDynamicJWKSProvider() oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecretFunc, jwksProviderIsUnused, timeoutsConfiguration) - idpLister := oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(&test.idp).Build() - subject := NewHandler(idpLister, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) - req := httptest.NewRequest(test.method, test.path, nil) + subject := NewHandler(test.idps.Build(), oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) + reqContext := context.WithValue(context.Background(), struct{ name string }{name: "test"}, "request-context") + req := httptest.NewRequest(test.method, test.path, nil).WithContext(reqContext) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) } @@ -682,12 +791,13 @@ func TestCallbackEndpoint(t *testing.T) { testutil.RequireSecurityHeaders(t, rsp) - if test.wantExchangeAndValidateTokensCall != nil { - require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) - test.wantExchangeAndValidateTokensCall.Ctx = req.Context() - require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) + if test.wantAuthcodeExchangeCall != nil { + test.wantAuthcodeExchangeCall.args.Ctx = reqContext + test.idps.RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens(t, + test.wantAuthcodeExchangeCall.performedByUpstreamName, test.wantAuthcodeExchangeCall.args, + ) } else { - require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + test.idps.RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t) } require.Equal(t, test.wantStatus, rsp.Code) @@ -749,6 +859,11 @@ func TestCallbackEndpoint(t *testing.T) { } } +type expectedAuthcodeExchange struct { + performedByUpstreamName string + args *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs +} + type requestPath struct { code, state *string } @@ -838,70 +953,20 @@ func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamSt return b } -type upstreamOIDCIdentityProviderBuilder struct { - idToken map[string]interface{} - usernameClaim, groupsClaim string - authcodeExchangeErr error -} - -func happyUpstream() *upstreamOIDCIdentityProviderBuilder { - return &upstreamOIDCIdentityProviderBuilder{ - usernameClaim: upstreamUsernameClaim, - groupsClaim: upstreamGroupsClaim, - idToken: map[string]interface{}{ - "iss": upstreamIssuer, - "sub": upstreamSubject, - upstreamUsernameClaim: upstreamUsername, - upstreamGroupsClaim: upstreamGroupMembership, - "other-claim": "should be ignored", - }, - } -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(value string) *upstreamOIDCIdentityProviderBuilder { - u.usernameClaim = value - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder { - u.usernameClaim = "" - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder { - u.groupsClaim = "" - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *upstreamOIDCIdentityProviderBuilder { - u.idToken[name] = value - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder { - delete(u.idToken, claim) - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder { - u.authcodeExchangeErr = err - return u -} - -func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider { - return oidctestutil.TestUpstreamOIDCIdentityProvider{ - Name: happyUpstreamIDPName, - ClientID: "some-client-id", - UsernameClaim: u.usernameClaim, - GroupsClaim: u.groupsClaim, - Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { - if u.authcodeExchangeErr != nil { - return nil, u.authcodeExchangeErr - } - return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}}, nil - }, - } +func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { + return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). + WithName(happyUpstreamIDPName). + WithClientID("some-client-id"). + WithScopes([]string{"scope1", "scope2"}). + WithUsernameClaim(oidcUpstreamUsernameClaim). + WithGroupsClaim(oidcUpstreamGroupsClaim). + WithIDTokenClaim("iss", oidcUpstreamIssuer). + WithIDTokenClaim("sub", oidcUpstreamSubject). + WithIDTokenClaim(oidcUpstreamUsernameClaim, oidcUpstreamUsername). + WithIDTokenClaim(oidcUpstreamGroupsClaim, oidcUpstreamGroupMembership). + WithIDTokenClaim("other-claim", "should be ignored"). + WithAllowPasswordGrant(false). + WithPasswordGrantError(errors.New("the callback endpoint should not use password grants")) } func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values { diff --git a/internal/testutil/oidctestutil/oidctestutil.go b/internal/testutil/oidctestutil/oidctestutil.go index de5aefeb..7591bbca 100644 --- a/internal/testutil/oidctestutil/oidctestutil.go +++ b/internal/testutil/oidctestutil/oidctestutil.go @@ -49,6 +49,14 @@ type ExchangeAuthcodeAndValidateTokenArgs struct { RedirectURI string } +// PasswordCredentialsGrantAndValidateTokensArgs is used to spy on calls to +// TestUpstreamOIDCIdentityProvider.PasswordCredentialsGrantAndValidateTokensFunc(). +type PasswordCredentialsGrantAndValidateTokensArgs struct { + Ctx context.Context + Username string + Password string +} + type TestUpstreamLDAPIdentityProvider struct { Name string URL *url.URL @@ -70,13 +78,14 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL { } type TestUpstreamOIDCIdentityProvider struct { - Name string - ClientID string - AuthorizationURL url.URL - UsernameClaim string - GroupsClaim string - Scopes []string - AllowPasswordGrant bool + Name string + ClientID string + AuthorizationURL url.URL + UsernameClaim string + GroupsClaim string + Scopes []string + AllowPasswordGrant bool + ExchangeAuthcodeAndValidateTokensFunc func( ctx context.Context, authcode string, @@ -84,8 +93,16 @@ type TestUpstreamOIDCIdentityProvider struct { expectedIDTokenNonce nonce.Nonce, ) (*oidctypes.Token, error) - exchangeAuthcodeAndValidateTokensCallCount int - exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs + PasswordCredentialsGrantAndValidateTokensFunc func( + ctx context.Context, + username string, + password string, + ) (*oidctypes.Token, error) + + exchangeAuthcodeAndValidateTokensCallCount int + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs + passwordCredentialsGrantAndValidateTokensCallCount int + passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs } func (u *TestUpstreamOIDCIdentityProvider) GetName() string { @@ -117,8 +134,16 @@ func (u *TestUpstreamOIDCIdentityProvider) AllowsPasswordGrant() bool { } func (u *TestUpstreamOIDCIdentityProvider) PasswordCredentialsGrantAndValidateTokens(ctx context.Context, username, password string) (*oidctypes.Token, error) { - // TODO implement this unit test helper - return nil, nil + if u.passwordCredentialsGrantAndValidateTokensArgs == nil { + u.passwordCredentialsGrantAndValidateTokensArgs = make([]*PasswordCredentialsGrantAndValidateTokensArgs, 0) + } + u.passwordCredentialsGrantAndValidateTokensCallCount++ + u.passwordCredentialsGrantAndValidateTokensArgs = append(u.passwordCredentialsGrantAndValidateTokensArgs, &PasswordCredentialsGrantAndValidateTokensArgs{ + Ctx: ctx, + Username: username, + Password: password, + }) + return u.PasswordCredentialsGrantAndValidateTokensFunc(ctx, username, password) } func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( @@ -190,10 +215,193 @@ func (b *UpstreamIDPListerBuilder) Build() provider.DynamicUpstreamIDPProvider { return idpProvider } +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPasswordCredentialsGrantAndValidateTokens( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *PasswordCredentialsGrantAndValidateTokensArgs, +) { + t.Helper() + var actualArgs *PasswordCredentialsGrantAndValidateTokensArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.passwordCredentialsGrantAndValidateTokensCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.passwordCredentialsGrantAndValidateTokensArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to PasswordCredentialsGrantAndValidateTokens() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "PasswordCredentialsGrantAndValidateTokens() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPasswordCredentialsGrantAndValidateTokens(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.passwordCredentialsGrantAndValidateTokensCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to PasswordCredentialsGrantAndValidateTokens()", + ) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens( + t *testing.T, + expectedPerformedByUpstreamName string, + expectedArgs *ExchangeAuthcodeAndValidateTokenArgs, +) { + t.Helper() + var actualArgs *ExchangeAuthcodeAndValidateTokenArgs + var actualNameOfUpstreamWhichMadeCall string + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + callCountOnThisUpstream := upstreamOIDC.exchangeAuthcodeAndValidateTokensCallCount + actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream + if callCountOnThisUpstream == 1 { + actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name + actualArgs = upstreamOIDC.exchangeAuthcodeAndValidateTokensArgs[0] + } + } + require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, + "should have been exactly one call to ExchangeAuthcodeAndValidateTokens() by all OIDC upstreams", + ) + require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, + "ExchangeAuthcodeAndValidateTokens() was called on the wrong OIDC upstream", + ) + require.Equal(t, expectedArgs, actualArgs) +} + +func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t *testing.T) { + t.Helper() + actualCallCountAcrossAllOIDCUpstreams := 0 + for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { + actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.exchangeAuthcodeAndValidateTokensCallCount + } + require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, + "expected exactly zero calls to ExchangeAuthcodeAndValidateTokens()", + ) +} + func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder { return &UpstreamIDPListerBuilder{} } +type TestUpstreamOIDCIdentityProviderBuilder struct { + name string + clientID string + scopes []string + idToken map[string]interface{} + usernameClaim string + groupsClaim string + authorizationURL url.URL + allowPasswordGrant bool + authcodeExchangeErr error + passwordGrantErr error +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { + u.name = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithClientID(value string) *TestUpstreamOIDCIdentityProviderBuilder { + u.clientID = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAuthorizationURL(value url.URL) *TestUpstreamOIDCIdentityProviderBuilder { + u.authorizationURL = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAllowPasswordGrant(value bool) *TestUpstreamOIDCIdentityProviderBuilder { + u.allowPasswordGrant = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithScopes(values []string) *TestUpstreamOIDCIdentityProviderBuilder { + u.scopes = values + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithUsernameClaim(value string) *TestUpstreamOIDCIdentityProviderBuilder { + u.usernameClaim = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *TestUpstreamOIDCIdentityProviderBuilder { + u.usernameClaim = "" + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithGroupsClaim(value string) *TestUpstreamOIDCIdentityProviderBuilder { + u.groupsClaim = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *TestUpstreamOIDCIdentityProviderBuilder { + u.groupsClaim = "" + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *TestUpstreamOIDCIdentityProviderBuilder { + if u.idToken == nil { + u.idToken = map[string]interface{}{} + } + u.idToken[name] = value + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *TestUpstreamOIDCIdentityProviderBuilder { + delete(u.idToken, claim) + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithUpstreamAuthcodeExchangeError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.authcodeExchangeErr = err + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err error) *TestUpstreamOIDCIdentityProviderBuilder { + u.passwordGrantErr = err + return u +} + +func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider { + return &TestUpstreamOIDCIdentityProvider{ + Name: u.name, + ClientID: u.clientID, + UsernameClaim: u.usernameClaim, + GroupsClaim: u.groupsClaim, + Scopes: u.scopes, + AllowPasswordGrant: u.allowPasswordGrant, + AuthorizationURL: u.authorizationURL, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.authcodeExchangeErr != nil { + return nil, u.authcodeExchangeErr + } + return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}}, nil + }, + PasswordCredentialsGrantAndValidateTokensFunc: func(ctx context.Context, username, password string) (*oidctypes.Token, error) { + if u.passwordGrantErr != nil { + return nil, u.passwordGrantErr + } + return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}}, nil + }, + } +} + +func NewTestUpstreamOIDCIdentityProviderBuilder() *TestUpstreamOIDCIdentityProviderBuilder { + return &TestUpstreamOIDCIdentityProviderBuilder{} +} + // Declare a separate type from the production code to ensure that the state param's contents was serialized // in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of // the serialized fields is the same, which doesn't really matter expect that we can make simpler equality