callback_handler.go: assert correct args are passed to token exchange

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-19 10:20:46 -05:00
parent 48e0250649
commit 2e62be3ebb
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
5 changed files with 173 additions and 118 deletions

View File

@ -271,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using GET without a CSRF cookie", name: "happy path using GET without a CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -289,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using GET with a CSRF cookie", name: "happy path using GET with a CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -307,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -327,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path when downstream redirect uri matches what is configured for client except for the port number", name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -349,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -366,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream client does not exist", name: "downstream client does not exist",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -381,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "response type is unsupported", name: "response type is unsupported",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -397,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream scopes do not match what is configured for client", name: "downstream scopes do not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -413,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing response type in request", name: "missing response type in request",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -429,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing client id in request", name: "missing client id in request",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -444,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -460,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -476,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -492,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -510,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
// through that part of the fosite library. // through that part of the fosite library.
name: "prompt param is not allowed to have none and another legal value at the same time", name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -526,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "OIDC validations are skipped when the openid scope was not requested", name: "OIDC validations are skipped when the openid scope was not requested",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -547,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "state does not have enough entropy", name: "state does not have enough entropy",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -563,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while encoding upstream state param", name: "error while encoding upstream state param",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -578,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while encoding CSRF cookie value for new cookie", name: "error while encoding CSRF cookie value for new cookie",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -593,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -608,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating nonce", name: "error while generating nonce",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
@ -623,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating PKCE", name: "error while generating PKCE",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -638,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while decoding CSRF cookie", name: "error while decoding CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -664,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "too many upstream providers are configured", name: "too many upstream providers are configured",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
@ -674,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "PUT is a bad method", name: "PUT is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPut, method: http.MethodPut,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -684,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "PATCH is a bad method", name: "PATCH is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPatch, method: http.MethodPatch,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -694,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "DELETE is a bad method", name: "DELETE is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodDelete, method: http.MethodDelete,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,

View File

@ -47,13 +47,14 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi
} }
// Grant the openid scope only if it was requested. // Grant the openid scope only if it was requested.
// TODO: shouldn't we be potentially granting more scopes than just openid...
grantOpenIDScopeIfRequested(authorizeRequester) grantOpenIDScopeIfRequested(authorizeRequester)
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
r.Context(), r.Context(),
"TODO", // TODO use the upstream authcode (code param) here r.URL.Query().Get("code"), // TODO: do we need to validate this?
"TODO", // TODO use the pkce value from the decoded state param here state.PKCECode,
"TODO", // TODO use the nonce value from the decoded state param here state.Nonce,
) )
if err != nil { if err != nil {
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")

View File

@ -21,7 +21,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient" "go.pinniped.dev/internal/oidcclient"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/pkce"
@ -35,6 +34,8 @@ const (
func TestCallbackEndpoint(t *testing.T) { func TestCallbackEndpoint(t *testing.T) {
const ( const (
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
happyUpstreamAuthcode = "upstream-auth-code"
) )
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
@ -170,10 +171,16 @@ func TestCallbackEndpoint(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{
Authcode: happyUpstreamAuthcode,
PKCECodeVerifier: pkce.Code(happyPKCE),
ExpectedIDTokenNonce: nonce.Nonce(happyNonce),
}
tests := []struct { tests := []struct {
name string name string
idpListGetter provider.DynamicUpstreamIDPProvider idp testutil.TestUpstreamOIDCIdentityProvider
method string method string
path string path string
csrfCookie string csrfCookie string
@ -184,12 +191,14 @@ func TestCallbackEndpoint(t *testing.T) {
// TODO: I am unused... // TODO: I am unused...
wantAuthcodeStored bool wantAuthcodeStored bool
wantGrantedOpenidScope bool wantGrantedOpenidScope bool
wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs
}{ }{
{ {
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(), path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
@ -197,6 +206,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantAuthcodeStored: true, wantAuthcodeStored: true,
wantGrantedOpenidScope: true, wantGrantedOpenidScope: true,
wantBody: "", wantBody: "",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
}, },
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
@ -247,7 +257,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(), path: newRequestPath().WithState("this-will-not-decode").String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -256,7 +266,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "state's internal version does not match what we want", name: "state's internal version does not match what we want",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(wrongVersionState).String(), path: newRequestPath().WithState(wrongVersionState).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -265,7 +275,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "state's downstream auth params element is invalid", name: "state's downstream auth params element is invalid",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -274,7 +284,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "state's downstream auth params are missing required value (e.g., client_id)", name: "state's downstream auth params are missing required value (e.g., client_id)",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(missingClientIDState).String(), path: newRequestPath().WithState(missingClientIDState).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -283,16 +293,17 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "state's downstream auth params does not contain openid scope", name: "state's downstream auth params does not contain openid scope",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(noOpenidScopeState).String(), path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
}, },
{ {
name: "the UpstreamOIDCProvider CRD has been deleted", name: "the UpstreamOIDCProvider CRD has been deleted",
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), idp: otherUpstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(), path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -301,7 +312,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "the CSRF cookie does not exist on request", name: "the CSRF cookie does not exist on request",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(), path: newRequestPath().WithState(happyState).String(),
wantStatus: http.StatusForbidden, wantStatus: http.StatusForbidden,
@ -309,7 +320,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(), path: newRequestPath().WithState(happyState).String(),
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
@ -318,7 +329,7 @@ func TestCallbackEndpoint(t *testing.T) {
}, },
{ {
name: "cookie csrf value does not match state csrf value", name: "cookie csrf value does not match state csrf value",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), idp: upstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(wrongCSRFValueState).String(), path: newRequestPath().WithState(wrongCSRFValueState).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
@ -329,12 +340,13 @@ func TestCallbackEndpoint(t *testing.T) {
// Upstream exchange // Upstream exchange
{ {
name: "upstream auth code exchange fails", name: "upstream auth code exchange fails",
idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider), idp: failedExchangeUpstreamOIDCIdentityProvider,
method: http.MethodGet, method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(), path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadGateway, wantStatus: http.StatusBadGateway,
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
}, },
} }
for _, test := range tests { for _, test := range tests {
@ -352,7 +364,8 @@ func TestCallbackEndpoint(t *testing.T) {
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) idpListGetter := testutil.NewIDPListGetter(&test.idp)
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
req := httptest.NewRequest(test.method, test.path, nil) req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)
@ -408,6 +421,14 @@ func TestCallbackEndpoint(t *testing.T) {
} else { } else {
require.Empty(t, rsp.Header().Values("Location")) require.Empty(t, rsp.Header().Values("Location"))
} }
if test.wantExchangeAndValidateTokensCall != nil {
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
} else {
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
}
}) })
} }
} }

View File

@ -109,7 +109,7 @@ func TestManager(t *testing.T) {
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
r.NoError(err) r.NoError(err)
idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{ idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{
Name: "test-idp", Name: "test-idp",
ClientID: "test-client-id", ClientID: "test-client-id",
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,

View File

@ -15,6 +15,15 @@ import (
// Test helpers for the OIDC package. // Test helpers for the OIDC package.
// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to
// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc().
type ExchangeAuthcodeAndValidateTokenArgs struct {
Ctx context.Context
Authcode string
PKCECodeVerifier pkce.Code
ExpectedIDTokenNonce nonce.Nonce
}
type TestUpstreamOIDCIdentityProvider struct { type TestUpstreamOIDCIdentityProvider struct {
Name string Name string
ClientID string ClientID string
@ -28,6 +37,9 @@ type TestUpstreamOIDCIdentityProvider struct {
pkceCodeVerifier pkce.Code, pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce, expectedIDTokenNonce nonce.Nonce,
) (oidcclient.Token, map[string]interface{}, error) ) (oidcclient.Token, map[string]interface{}, error)
exchangeAuthcodeAndValidateTokensCallCount int
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
} }
func (u *TestUpstreamOIDCIdentityProvider) GetName() string { func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
@ -60,14 +72,35 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
pkceCodeVerifier pkce.Code, pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce, expectedIDTokenNonce nonce.Nonce,
) (oidcclient.Token, map[string]interface{}, error) { ) (oidcclient.Token, map[string]interface{}, error) {
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
}
u.exchangeAuthcodeAndValidateTokensCallCount++
u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{
Ctx: ctx,
Authcode: authcode,
PKCECodeVerifier: pkceCodeVerifier,
ExpectedIDTokenNonce: expectedIDTokenNonce,
})
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
} }
func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int {
return u.exchangeAuthcodeAndValidateTokensCallCount
}
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs {
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
}
return u.exchangeAuthcodeAndValidateTokensArgs[call]
}
func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider() idpProvider := provider.NewDynamicUpstreamIDPProvider()
upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
for i := range upstreamOIDCIdentityProviders { for i := range upstreamOIDCIdentityProviders {
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i]) upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i])
} }
idpProvider.SetIDPList(upstreams) idpProvider.SetIDPList(upstreams)
return idpProvider return idpProvider