callback_handler.go: assert correct args are passed to token exchange
Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
parent
48e0250649
commit
2e62be3ebb
@ -271,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "happy path using GET without a CSRF cookie",
|
name: "happy path using GET without a CSRF cookie",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -289,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "happy path using GET with a CSRF cookie",
|
name: "happy path using GET with a CSRF cookie",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -307,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "happy path using POST",
|
name: "happy path using POST",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -327,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
|
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -349,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "downstream redirect uri does not match what is configured for client",
|
name: "downstream redirect uri does not match what is configured for client",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -366,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "downstream client does not exist",
|
name: "downstream client does not exist",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -381,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "response type is unsupported",
|
name: "response type is unsupported",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -397,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "downstream scopes do not match what is configured for client",
|
name: "downstream scopes do not match what is configured for client",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -413,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing response type in request",
|
name: "missing response type in request",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -429,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing client id in request",
|
name: "missing client id in request",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -444,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -460,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -476,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -492,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -510,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
// through that part of the fosite library.
|
// through that part of the fosite library.
|
||||||
name: "prompt param is not allowed to have none and another legal value at the same time",
|
name: "prompt param is not allowed to have none and another legal value at the same time",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -526,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "OIDC validations are skipped when the openid scope was not requested",
|
name: "OIDC validations are skipped when the openid scope was not requested",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -547,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "state does not have enough entropy",
|
name: "state does not have enough entropy",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -563,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while encoding upstream state param",
|
name: "error while encoding upstream state param",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -578,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while encoding CSRF cookie value for new cookie",
|
name: "error while encoding CSRF cookie value for new cookie",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -593,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while generating CSRF token",
|
name: "error while generating CSRF token",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
|
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -608,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while generating nonce",
|
name: "error while generating nonce",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
||||||
@ -623,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while generating PKCE",
|
name: "error while generating PKCE",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -638,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "error while decoding CSRF cookie",
|
name: "error while decoding CSRF cookie",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
generateCSRF: happyCSRFGenerator,
|
generateCSRF: happyCSRFGenerator,
|
||||||
generatePKCE: happyPKCEGenerator,
|
generatePKCE: happyPKCEGenerator,
|
||||||
generateNonce: happyNonceGenerator,
|
generateNonce: happyNonceGenerator,
|
||||||
@ -664,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "too many upstream providers are configured",
|
name: "too many upstream providers are configured",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: happyGetRequestPath,
|
path: happyGetRequestPath,
|
||||||
wantStatus: http.StatusUnprocessableEntity,
|
wantStatus: http.StatusUnprocessableEntity,
|
||||||
@ -674,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "PUT is a bad method",
|
name: "PUT is a bad method",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
method: http.MethodPut,
|
method: http.MethodPut,
|
||||||
path: "/some/path",
|
path: "/some/path",
|
||||||
wantStatus: http.StatusMethodNotAllowed,
|
wantStatus: http.StatusMethodNotAllowed,
|
||||||
@ -684,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "PATCH is a bad method",
|
name: "PATCH is a bad method",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
method: http.MethodPatch,
|
method: http.MethodPatch,
|
||||||
path: "/some/path",
|
path: "/some/path",
|
||||||
wantStatus: http.StatusMethodNotAllowed,
|
wantStatus: http.StatusMethodNotAllowed,
|
||||||
@ -694,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "DELETE is a bad method",
|
name: "DELETE is a bad method",
|
||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
|
||||||
method: http.MethodDelete,
|
method: http.MethodDelete,
|
||||||
path: "/some/path",
|
path: "/some/path",
|
||||||
wantStatus: http.StatusMethodNotAllowed,
|
wantStatus: http.StatusMethodNotAllowed,
|
||||||
|
@ -47,13 +47,14 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Grant the openid scope only if it was requested.
|
// Grant the openid scope only if it was requested.
|
||||||
|
// TODO: shouldn't we be potentially granting more scopes than just openid...
|
||||||
grantOpenIDScopeIfRequested(authorizeRequester)
|
grantOpenIDScopeIfRequested(authorizeRequester)
|
||||||
|
|
||||||
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
|
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
|
||||||
r.Context(),
|
r.Context(),
|
||||||
"TODO", // TODO use the upstream authcode (code param) here
|
r.URL.Query().Get("code"), // TODO: do we need to validate this?
|
||||||
"TODO", // TODO use the pkce value from the decoded state param here
|
state.PKCECode,
|
||||||
"TODO", // TODO use the nonce value from the decoded state param here
|
state.Nonce,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
|
||||||
|
@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/oidc"
|
"go.pinniped.dev/internal/oidc"
|
||||||
"go.pinniped.dev/internal/oidc/provider"
|
|
||||||
"go.pinniped.dev/internal/oidcclient"
|
"go.pinniped.dev/internal/oidcclient"
|
||||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||||
"go.pinniped.dev/internal/oidcclient/pkce"
|
"go.pinniped.dev/internal/oidcclient/pkce"
|
||||||
@ -35,6 +34,8 @@ const (
|
|||||||
func TestCallbackEndpoint(t *testing.T) {
|
func TestCallbackEndpoint(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
downstreamRedirectURI = "http://127.0.0.1/callback"
|
downstreamRedirectURI = "http://127.0.0.1/callback"
|
||||||
|
|
||||||
|
happyUpstreamAuthcode = "upstream-auth-code"
|
||||||
)
|
)
|
||||||
|
|
||||||
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
|
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
|
||||||
@ -170,10 +171,16 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
|
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
|
||||||
|
|
||||||
|
happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{
|
||||||
|
Authcode: happyUpstreamAuthcode,
|
||||||
|
PKCECodeVerifier: pkce.Code(happyPKCE),
|
||||||
|
ExpectedIDTokenNonce: nonce.Nonce(happyNonce),
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
||||||
idpListGetter provider.DynamicUpstreamIDPProvider
|
idp testutil.TestUpstreamOIDCIdentityProvider
|
||||||
method string
|
method string
|
||||||
path string
|
path string
|
||||||
csrfCookie string
|
csrfCookie string
|
||||||
@ -184,12 +191,14 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
// TODO: I am unused...
|
// TODO: I am unused...
|
||||||
wantAuthcodeStored bool
|
wantAuthcodeStored bool
|
||||||
wantGrantedOpenidScope bool
|
wantGrantedOpenidScope bool
|
||||||
|
|
||||||
|
wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
|
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(happyState).String(),
|
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
|
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
|
||||||
@ -197,6 +206,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
wantAuthcodeStored: true,
|
wantAuthcodeStored: true,
|
||||||
wantGrantedOpenidScope: true,
|
wantGrantedOpenidScope: true,
|
||||||
wantBody: "",
|
wantBody: "",
|
||||||
|
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||||
},
|
},
|
||||||
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
|
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
|
||||||
|
|
||||||
@ -247,7 +257,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState("this-will-not-decode").String(),
|
path: newRequestPath().WithState("this-will-not-decode").String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -256,7 +266,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "state's internal version does not match what we want",
|
name: "state's internal version does not match what we want",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(wrongVersionState).String(),
|
path: newRequestPath().WithState(wrongVersionState).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -265,7 +275,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "state's downstream auth params element is invalid",
|
name: "state's downstream auth params element is invalid",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(),
|
path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -274,7 +284,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "state's downstream auth params are missing required value (e.g., client_id)",
|
name: "state's downstream auth params are missing required value (e.g., client_id)",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(missingClientIDState).String(),
|
path: newRequestPath().WithState(missingClientIDState).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -283,16 +293,17 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "state's downstream auth params does not contain openid scope",
|
name: "state's downstream auth params does not contain openid scope",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(noOpenidScopeState).String(),
|
path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
wantStatus: http.StatusFound,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
|
||||||
|
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "the UpstreamOIDCProvider CRD has been deleted",
|
name: "the UpstreamOIDCProvider CRD has been deleted",
|
||||||
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
|
idp: otherUpstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(happyState).String(),
|
path: newRequestPath().WithState(happyState).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -301,7 +312,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "the CSRF cookie does not exist on request",
|
name: "the CSRF cookie does not exist on request",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(happyState).String(),
|
path: newRequestPath().WithState(happyState).String(),
|
||||||
wantStatus: http.StatusForbidden,
|
wantStatus: http.StatusForbidden,
|
||||||
@ -309,7 +320,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(happyState).String(),
|
path: newRequestPath().WithState(happyState).String(),
|
||||||
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
|
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
|
||||||
@ -318,7 +329,7 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "cookie csrf value does not match state csrf value",
|
name: "cookie csrf value does not match state csrf value",
|
||||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
idp: upstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(wrongCSRFValueState).String(),
|
path: newRequestPath().WithState(wrongCSRFValueState).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
@ -329,12 +340,13 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
// Upstream exchange
|
// Upstream exchange
|
||||||
{
|
{
|
||||||
name: "upstream auth code exchange fails",
|
name: "upstream auth code exchange fails",
|
||||||
idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider),
|
idp: failedExchangeUpstreamOIDCIdentityProvider,
|
||||||
method: http.MethodGet,
|
method: http.MethodGet,
|
||||||
path: newRequestPath().WithState(happyState).String(),
|
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
|
||||||
csrfCookie: happyCSRFCookie,
|
csrfCookie: happyCSRFCookie,
|
||||||
wantStatus: http.StatusBadGateway,
|
wantStatus: http.StatusBadGateway,
|
||||||
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
|
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
|
||||||
|
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
@ -352,7 +364,8 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
|
||||||
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
|
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
|
||||||
|
|
||||||
subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
|
idpListGetter := testutil.NewIDPListGetter(&test.idp)
|
||||||
|
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
|
||||||
req := httptest.NewRequest(test.method, test.path, nil)
|
req := httptest.NewRequest(test.method, test.path, nil)
|
||||||
if test.csrfCookie != "" {
|
if test.csrfCookie != "" {
|
||||||
req.Header.Set("Cookie", test.csrfCookie)
|
req.Header.Set("Cookie", test.csrfCookie)
|
||||||
@ -408,6 +421,14 @@ func TestCallbackEndpoint(t *testing.T) {
|
|||||||
} else {
|
} else {
|
||||||
require.Empty(t, rsp.Header().Values("Location"))
|
require.Empty(t, rsp.Header().Values("Location"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if test.wantExchangeAndValidateTokensCall != nil {
|
||||||
|
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
|
||||||
|
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
|
||||||
|
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
|
||||||
|
} else {
|
||||||
|
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ func TestManager(t *testing.T) {
|
|||||||
|
|
||||||
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
|
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
|
||||||
r.NoError(err)
|
r.NoError(err)
|
||||||
idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{
|
idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{
|
||||||
Name: "test-idp",
|
Name: "test-idp",
|
||||||
ClientID: "test-client-id",
|
ClientID: "test-client-id",
|
||||||
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
|
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
|
||||||
|
@ -15,6 +15,15 @@ import (
|
|||||||
|
|
||||||
// Test helpers for the OIDC package.
|
// Test helpers for the OIDC package.
|
||||||
|
|
||||||
|
// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to
|
||||||
|
// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc().
|
||||||
|
type ExchangeAuthcodeAndValidateTokenArgs struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Authcode string
|
||||||
|
PKCECodeVerifier pkce.Code
|
||||||
|
ExpectedIDTokenNonce nonce.Nonce
|
||||||
|
}
|
||||||
|
|
||||||
type TestUpstreamOIDCIdentityProvider struct {
|
type TestUpstreamOIDCIdentityProvider struct {
|
||||||
Name string
|
Name string
|
||||||
ClientID string
|
ClientID string
|
||||||
@ -28,6 +37,9 @@ type TestUpstreamOIDCIdentityProvider struct {
|
|||||||
pkceCodeVerifier pkce.Code,
|
pkceCodeVerifier pkce.Code,
|
||||||
expectedIDTokenNonce nonce.Nonce,
|
expectedIDTokenNonce nonce.Nonce,
|
||||||
) (oidcclient.Token, map[string]interface{}, error)
|
) (oidcclient.Token, map[string]interface{}, error)
|
||||||
|
|
||||||
|
exchangeAuthcodeAndValidateTokensCallCount int
|
||||||
|
exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
|
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
|
||||||
@ -60,14 +72,35 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
|
|||||||
pkceCodeVerifier pkce.Code,
|
pkceCodeVerifier pkce.Code,
|
||||||
expectedIDTokenNonce nonce.Nonce,
|
expectedIDTokenNonce nonce.Nonce,
|
||||||
) (oidcclient.Token, map[string]interface{}, error) {
|
) (oidcclient.Token, map[string]interface{}, error) {
|
||||||
|
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
|
||||||
|
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
|
||||||
|
}
|
||||||
|
u.exchangeAuthcodeAndValidateTokensCallCount++
|
||||||
|
u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{
|
||||||
|
Ctx: ctx,
|
||||||
|
Authcode: authcode,
|
||||||
|
PKCECodeVerifier: pkceCodeVerifier,
|
||||||
|
ExpectedIDTokenNonce: expectedIDTokenNonce,
|
||||||
|
})
|
||||||
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
|
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int {
|
||||||
|
return u.exchangeAuthcodeAndValidateTokensCallCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs {
|
||||||
|
if u.exchangeAuthcodeAndValidateTokensArgs == nil {
|
||||||
|
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
|
||||||
|
}
|
||||||
|
return u.exchangeAuthcodeAndValidateTokensArgs[call]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
||||||
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||||
upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
|
upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
|
||||||
for i := range upstreamOIDCIdentityProviders {
|
for i := range upstreamOIDCIdentityProviders {
|
||||||
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i])
|
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i])
|
||||||
}
|
}
|
||||||
idpProvider.SetIDPList(upstreams)
|
idpProvider.SetIDPList(upstreams)
|
||||||
return idpProvider
|
return idpProvider
|
||||||
|
Loading…
Reference in New Issue
Block a user