callback_handler.go: Add JWT Issuer claim to storage

This commit is contained in:
Ryan Richard 2020-11-19 08:35:23 -08:00
parent ace861f722
commit ee84f31f42
4 changed files with 52 additions and 53 deletions

View File

@ -24,7 +24,7 @@ import (
) )
func NewHandler( func NewHandler(
issuer string, downstreamIssuer string,
idpListGetter oidc.IDPListGetter, idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
@ -92,7 +92,7 @@ func NewHandler(
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: upstreamIDP.GetAuthorizationURL().String(), AuthURL: upstreamIDP.GetAuthorizationURL().String(),
}, },
RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.GetName()), RedirectURL: fmt.Sprintf("%s/callback/%s", downstreamIssuer, upstreamIDP.GetName()),
Scopes: upstreamIDP.GetScopes(), Scopes: upstreamIDP.GetScopes(),
} }

View File

@ -28,6 +28,7 @@ import (
func TestAuthorizationEndpoint(t *testing.T) { func TestAuthorizationEndpoint(t *testing.T) {
const ( const (
downstreamIssuer = "https://my-downstream-issuer.com/some-path"
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback" downstreamRedirectURIWithDifferentPort = "http://127.0.0.1:42/callback"
) )
@ -120,8 +121,6 @@ func TestAuthorizationEndpoint(t *testing.T) {
Scopes: []string{"scope1", "scope2"}, Scopes: []string{"scope1", "scope2"},
} }
issuer := "https://my-issuer.com/some-path"
// Configure fosite the same way that the production code would, using NullStorage to turn off storage. // Configure fosite the same way that the production code would, using NullStorage to turn off storage.
oauthStore := oidc.NullStorage{} oauthStore := oidc.NullStorage{}
hmacSecret := []byte("some secret - must have at least 32 bytes") hmacSecret := []byte("some secret - must have at least 32 bytes")
@ -233,7 +232,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-idp", "redirect_uri": downstreamIssuer + "/callback/some-idp",
}) })
} }
@ -270,7 +269,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
tests := []testCase{ tests := []testCase{
{ {
name: "happy path using GET without a CSRF cookie", name: "happy path using GET without a CSRF cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -288,7 +287,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "happy path using GET with a CSRF cookie", name: "happy path using GET with a CSRF cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -306,7 +305,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -326,7 +325,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "happy path when downstream redirect uri matches what is configured for client except for the port number", name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -348,7 +347,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -365,7 +364,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "downstream client does not exist", name: "downstream client does not exist",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -380,7 +379,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "response type is unsupported", name: "response type is unsupported",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -396,7 +395,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "downstream scopes do not match what is configured for client", name: "downstream scopes do not match what is configured for client",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -412,7 +411,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing response type in request", name: "missing response type in request",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -428,7 +427,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing client id in request", name: "missing client id in request",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -443,7 +442,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -459,7 +458,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -475,7 +474,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -491,7 +490,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -509,7 +508,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
// This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running // This is just one of the many OIDC validations run by fosite. This test is to ensure that we are running
// through that part of the fosite library. // through that part of the fosite library.
name: "prompt param is not allowed to have none and another legal value at the same time", name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -525,7 +524,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "OIDC validations are skipped when the openid scope was not requested", name: "OIDC validations are skipped when the openid scope was not requested",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -546,7 +545,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "state does not have enough entropy", name: "state does not have enough entropy",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -562,7 +561,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while encoding upstream state param", name: "error while encoding upstream state param",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -577,7 +576,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while encoding CSRF cookie value for new cookie", name: "error while encoding CSRF cookie value for new cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -592,7 +591,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -607,7 +606,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating nonce", name: "error while generating nonce",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -622,7 +621,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while generating PKCE", name: "error while generating PKCE",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
@ -637,7 +636,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "error while decoding CSRF cookie", name: "error while decoding CSRF cookie",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
@ -653,7 +652,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "no upstream providers are configured", name: "no upstream providers are configured",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(), // empty idpListGetter: testutil.NewIDPListGetter(), // empty
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
@ -663,7 +662,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "too many upstream providers are configured", name: "too many upstream providers are configured",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
@ -673,7 +672,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "PUT is a bad method", name: "PUT is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPut, method: http.MethodPut,
path: "/some/path", path: "/some/path",
@ -683,7 +682,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "PATCH is a bad method", name: "PATCH is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPatch, method: http.MethodPatch,
path: "/some/path", path: "/some/path",
@ -693,7 +692,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
}, },
{ {
name: "DELETE is a bad method", name: "DELETE is a bad method",
issuer: issuer, issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodDelete, method: http.MethodDelete,
path: "/some/path", path: "/some/path",
@ -792,7 +791,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
"nonce": happyNonce, "nonce": happyNonce,
"code_challenge": expectedUpstreamCodeChallenge, "code_challenge": expectedUpstreamCodeChallenge,
"code_challenge_method": "S256", "code_challenge_method": "S256",
"redirect_uri": issuer + "/callback/some-other-idp", "redirect_uri": downstreamIssuer + "/callback/some-other-idp",
}, },
) )
test.wantBodyString = fmt.Sprintf(`<a href="%s">Found</a>.%s`, test.wantBodyString = fmt.Sprintf(`<a href="%s">Found</a>.%s`,

View File

@ -36,7 +36,12 @@ const (
downstreamGroupsClaim = "oidc.pinniped.dev/groups" downstreamGroupsClaim = "oidc.pinniped.dev/groups"
) )
func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler { func NewHandler(
downstreamIssuer string,
idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider,
stateDecoder, cookieDecoder oidc.Decoder,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder) state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil { if err != nil {
@ -106,7 +111,7 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi
now := time.Now() now := time.Now()
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Issuer: "https://fosite.my-application.com", // TODO use the right value here Issuer: downstreamIssuer,
Subject: username, Subject: username,
Audience: []string{"my-client"}, // TODO use the right value here Audience: []string{"my-client"}, // TODO use the right value here
ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here

View File

@ -33,8 +33,8 @@ const (
func TestCallbackEndpoint(t *testing.T) { func TestCallbackEndpoint(t *testing.T) {
const ( const (
downstreamIssuer = "https://my-downstream-issuer.com/path"
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
happyUpstreamAuthcode = "upstream-auth-code" happyUpstreamAuthcode = "upstream-auth-code"
) )
@ -207,9 +207,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus int wantStatus int
wantBody string wantBody string
wantRedirectLocationRegexp string wantRedirectLocationRegexp string
// TODO: I am unused... wantGrantedOpenidScope bool
wantAuthcodeStored bool
wantGrantedOpenidScope bool
wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs
}{ }{
@ -221,7 +219,6 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantAuthcodeStored: true,
wantGrantedOpenidScope: true, wantGrantedOpenidScope: true,
wantBody: "", wantBody: "",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
@ -234,7 +231,6 @@ func TestCallbackEndpoint(t *testing.T) {
csrfCookie: happyCSRFCookie, csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound, wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantAuthcodeStored: true,
wantGrantedOpenidScope: true, wantGrantedOpenidScope: true,
wantBody: "", wantBody: "",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
@ -396,7 +392,7 @@ func TestCallbackEndpoint(t *testing.T) {
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
idpListGetter := testutil.NewIDPListGetter(&test.idp) idpListGetter := testutil.NewIDPListGetter(&test.idp)
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) subject := NewHandler(downstreamIssuer, idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
req := httptest.NewRequest(test.method, test.path, nil) req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)
@ -406,9 +402,15 @@ func TestCallbackEndpoint(t *testing.T) {
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
require.Equal(t, test.wantStatus, rsp.Code) if test.wantExchangeAndValidateTokensCall != nil {
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
} else {
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
}
require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions") require.Equal(t, test.wantStatus, rsp.Code)
if test.wantBody != "" { if test.wantBody != "" {
require.Equal(t, test.wantBody, rsp.Body.String()) require.Equal(t, test.wantBody, rsp.Body.String())
@ -448,19 +450,12 @@ func TestCallbackEndpoint(t *testing.T) {
} else { } else {
require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") require.NotContains(t, storedRequest.GetGrantedScopes(), "openid")
} }
require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer)
require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject)
require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"])
} else { } else {
require.Empty(t, rsp.Header().Values("Location")) require.Empty(t, rsp.Header().Values("Location"))
} }
if test.wantExchangeAndValidateTokensCall != nil {
require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
test.wantExchangeAndValidateTokensCall.Ctx = req.Context()
require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0))
} else {
require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount())
}
}) })
} }
} }