token_handler_test.go: Refactor how we specify the expected results

- This is to make it easier for the token exchange branch to also edit
  this test without causing a lot of merge conflicts with the
  refresh token branch, to enable parallel development of closely
  related stories.
This commit is contained in:
Ryan Richard 2020-12-08 18:10:55 -08:00
parent 170982a688
commit ef4ef583dc

View File

@ -234,10 +234,11 @@ type authcodeExchangeInputs struct {
}, },
) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey)
wantStatus int wantStatus int
wantBodyFields []string wantSuccessBodyFields []string
wantRequestedScopes []string wantErrorResponseBody string
wantExactBody string wantRequestedScopes []string
wantGrantedScopes []string
} }
func TestTokenEndpoint(t *testing.T) { func TestTokenEndpoint(t *testing.T) {
@ -249,9 +250,10 @@ func TestTokenEndpoint(t *testing.T) {
{ {
name: "request is valid and tokens are issued", name: "request is valid and tokens are issued",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantBodyFields: []string{"id_token", "access_token", "token_type", "scope", "expires_in"}, // no refresh token wantSuccessBodyFields: []string{"id_token", "access_token", "token_type", "scope", "expires_in"}, // no refresh token
wantRequestedScopes: []string{"openid", "profile", "email"}, wantRequestedScopes: []string{"openid", "profile", "email"},
wantGrantedScopes: []string{"openid"},
}, },
}, },
{ {
@ -260,9 +262,10 @@ func TestTokenEndpoint(t *testing.T) {
modifyAuthRequest: func(authRequest *http.Request) { modifyAuthRequest: func(authRequest *http.Request) {
authRequest.Form.Set("scope", "profile email") authRequest.Form.Set("scope", "profile email")
}, },
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantBodyFields: []string{"access_token", "token_type", "scope", "expires_in"}, // no id or refresh tokens wantSuccessBodyFields: []string{"access_token", "token_type", "scope", "expires_in"}, // no id or refresh tokens
wantRequestedScopes: []string{"profile", "email"}, wantRequestedScopes: []string{"profile", "email"},
wantGrantedScopes: []string{},
}, },
}, },
{ {
@ -271,9 +274,10 @@ func TestTokenEndpoint(t *testing.T) {
modifyAuthRequest: func(authRequest *http.Request) { modifyAuthRequest: func(authRequest *http.Request) {
authRequest.Form.Set("scope", "openid offline_access") authRequest.Form.Set("scope", "openid offline_access")
}, },
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantBodyFields: []string{"id_token", "access_token", "token_type", "scope", "expires_in", "refresh_token"}, // all possible tokens wantSuccessBodyFields: []string{"id_token", "access_token", "token_type", "scope", "expires_in", "refresh_token"}, // all possible tokens
wantRequestedScopes: []string{"openid", "offline_access"}, wantRequestedScopes: []string{"openid", "offline_access"},
wantGrantedScopes: []string{"openid", "offline_access"},
}, },
}, },
{ {
@ -282,9 +286,10 @@ func TestTokenEndpoint(t *testing.T) {
modifyAuthRequest: func(authRequest *http.Request) { modifyAuthRequest: func(authRequest *http.Request) {
authRequest.Form.Set("scope", "offline_access") authRequest.Form.Set("scope", "offline_access")
}, },
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantBodyFields: []string{"access_token", "token_type", "scope", "expires_in", "refresh_token"}, // no id token wantSuccessBodyFields: []string{"access_token", "token_type", "scope", "expires_in", "refresh_token"}, // no id token
wantRequestedScopes: []string{"offline_access"}, wantRequestedScopes: []string{"offline_access"},
wantGrantedScopes: []string{"offline_access"},
}, },
}, },
@ -292,41 +297,41 @@ func TestTokenEndpoint(t *testing.T) {
{ {
name: "GET method is wrong", name: "GET method is wrong",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodGet }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodGet },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidMethodErrorBody("GET"), wantErrorResponseBody: fositeInvalidMethodErrorBody("GET"),
}, },
}, },
{ {
name: "PUT method is wrong", name: "PUT method is wrong",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodPut }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodPut },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidMethodErrorBody("PUT"), wantErrorResponseBody: fositeInvalidMethodErrorBody("PUT"),
}, },
}, },
{ {
name: "PATCH method is wrong", name: "PATCH method is wrong",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodPatch }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodPatch },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidMethodErrorBody("PATCH"), wantErrorResponseBody: fositeInvalidMethodErrorBody("PATCH"),
}, },
}, },
{ {
name: "DELETE method is wrong", name: "DELETE method is wrong",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodDelete }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Method = http.MethodDelete },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidMethodErrorBody("DELETE"), wantErrorResponseBody: fositeInvalidMethodErrorBody("DELETE"),
}, },
}, },
{ {
name: "content type is invalid", name: "content type is invalid",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Header.Set("Content-Type", "text/plain") }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Header.Set("Content-Type", "text/plain") },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeEmptyPayloadErrorBody, wantErrorResponseBody: fositeEmptyPayloadErrorBody,
}, },
}, },
{ {
@ -335,16 +340,16 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = ioutil.NopCloser(strings.NewReader("this newline character is not allowed in a form serialization: \n")) r.Body = ioutil.NopCloser(strings.NewReader("this newline character is not allowed in a form serialization: \n"))
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeMissingGrantTypeErrorBody, wantErrorResponseBody: fositeMissingGrantTypeErrorBody,
}, },
}, },
{ {
name: "payload is empty", name: "payload is empty",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
modifyTokenRequest: func(r *http.Request, authCode string) { r.Body = nil }, modifyTokenRequest: func(r *http.Request, authCode string) { r.Body = nil },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidPayloadErrorBody, wantErrorResponseBody: fositeInvalidPayloadErrorBody,
}, },
}, },
{ {
@ -353,8 +358,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithGrantType("").ReadCloser() r.Body = happyBody(authCode).WithGrantType("").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeMissingGrantTypeErrorBody, wantErrorResponseBody: fositeMissingGrantTypeErrorBody,
}, },
}, },
{ {
@ -363,8 +368,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithGrantType("bogus").ReadCloser() r.Body = happyBody(authCode).WithGrantType("bogus").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidRequestErrorBody, wantErrorResponseBody: fositeInvalidRequestErrorBody,
}, },
}, },
{ {
@ -373,8 +378,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithClientID("").ReadCloser() r.Body = happyBody(authCode).WithClientID("").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeMissingClientErrorBody, wantErrorResponseBody: fositeMissingClientErrorBody,
}, },
}, },
{ {
@ -383,8 +388,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithClientID("bogus").ReadCloser() r.Body = happyBody(authCode).WithClientID("bogus").ReadCloser()
}, },
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantExactBody: fositeInvalidClientErrorBody, wantErrorResponseBody: fositeInvalidClientErrorBody,
}, },
}, },
{ {
@ -393,8 +398,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithAuthCode("").ReadCloser() r.Body = happyBody(authCode).WithAuthCode("").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidAuthCodeErrorBody, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody,
}, },
}, },
{ {
@ -403,8 +408,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithAuthCode("bogus").ReadCloser() r.Body = happyBody(authCode).WithAuthCode("bogus").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidAuthCodeErrorBody, wantErrorResponseBody: fositeInvalidAuthCodeErrorBody,
}, },
}, },
{ {
@ -424,8 +429,8 @@ func TestTokenEndpoint(t *testing.T) {
err := s.InvalidateAuthorizeCodeSession(context.Background(), getFositeDataSignature(t, authCode)) err := s.InvalidateAuthorizeCodeSession(context.Background(), getFositeDataSignature(t, authCode))
require.NoError(t, err) require.NoError(t, err)
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeReusedAuthCodeErrorBody, wantErrorResponseBody: fositeReusedAuthCodeErrorBody,
}, },
}, },
{ {
@ -434,8 +439,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithRedirectURI("").ReadCloser() r.Body = happyBody(authCode).WithRedirectURI("").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidRedirectURIErrorBody, wantErrorResponseBody: fositeInvalidRedirectURIErrorBody,
}, },
}, },
{ {
@ -444,8 +449,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithRedirectURI("bogus").ReadCloser() r.Body = happyBody(authCode).WithRedirectURI("bogus").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeInvalidRedirectURIErrorBody, wantErrorResponseBody: fositeInvalidRedirectURIErrorBody,
}, },
}, },
{ {
@ -454,8 +459,8 @@ func TestTokenEndpoint(t *testing.T) {
modifyTokenRequest: func(r *http.Request, authCode string) { modifyTokenRequest: func(r *http.Request, authCode string) {
r.Body = happyBody(authCode).WithPKCE("").ReadCloser() r.Body = happyBody(authCode).WithPKCE("").ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeMissingPKCEVerifierErrorBody, wantErrorResponseBody: fositeMissingPKCEVerifierErrorBody,
}, },
}, },
{ {
@ -466,16 +471,16 @@ func TestTokenEndpoint(t *testing.T) {
"bogus-verifier-that-is-at-least-43-characters-for-the-sake-of-entropy", "bogus-verifier-that-is-at-least-43-characters-for-the-sake-of-entropy",
).ReadCloser() ).ReadCloser()
}, },
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
wantExactBody: fositeWrongPKCEVerifierErrorBody, wantErrorResponseBody: fositeWrongPKCEVerifierErrorBody,
}, },
}, },
{ {
name: "private signing key for JWTs has not yet been provided by the controller who is responsible for dynamically providing it", name: "private signing key for JWTs has not yet been provided by the controller who is responsible for dynamically providing it",
authcodeExchange: authcodeExchangeInputs{ authcodeExchange: authcodeExchangeInputs{
makeOathHelper: makeOauthHelperWithNilPrivateJWTSigningKey, makeOathHelper: makeOauthHelperWithNilPrivateJWTSigningKey,
wantStatus: http.StatusServiceUnavailable, wantStatus: http.StatusServiceUnavailable,
wantExactBody: fositeTemporarilyUnavailableErrorBody, wantErrorResponseBody: fositeTemporarilyUnavailableErrorBody,
}, },
}, },
} }
@ -489,10 +494,8 @@ func TestTokenEndpoint(t *testing.T) {
func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) { func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
authcodeExchange authcodeExchangeInputs authcodeExchange authcodeExchangeInputs
wantGrantedOpenidScope bool
wantGrantedOfflineAccessScope bool
}{ }{
{ {
name: "authcode exchange succeeds once and then fails when the same authcode is used again", name: "authcode exchange succeeds once and then fails when the same authcode is used again",
@ -500,12 +503,11 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) {
modifyAuthRequest: func(authRequest *http.Request) { modifyAuthRequest: func(authRequest *http.Request) {
authRequest.Form.Set("scope", "openid offline_access profile email") authRequest.Form.Set("scope", "openid offline_access profile email")
}, },
wantStatus: http.StatusOK, wantStatus: http.StatusOK,
wantBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"}, wantSuccessBodyFields: []string{"id_token", "refresh_token", "access_token", "token_type", "expires_in", "scope"},
wantRequestedScopes: []string{"openid", "offline_access", "profile", "email"}, wantRequestedScopes: []string{"openid", "offline_access", "profile", "email"},
wantGrantedScopes: []string{"openid", "offline_access"},
}, },
wantGrantedOpenidScope: true,
wantGrantedOfflineAccessScope: true,
}, },
} }
for _, test := range tests { for _, test := range tests {
@ -537,7 +539,7 @@ func TestTokenEndpointWhenAuthcodeIsUsedTwice(t *testing.T) {
// This was previously invalidated by the first request, so it remains invalidated // This was previously invalidated by the first request, so it remains invalidated
requireInvalidPKCEStorage(t, authCode, oauthStore) requireInvalidPKCEStorage(t, authCode, oauthStore)
// Fosite never cleans up OpenID Connect session storage, so it is still there // Fosite never cleans up OpenID Connect session storage, so it is still there
requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.authcodeExchange.wantRequestedScopes, test.wantGrantedOpenidScope, test.wantGrantedOfflineAccessScope) requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.authcodeExchange.wantRequestedScopes, test.authcodeExchange.wantGrantedScopes)
// Check that the access token and refresh token storage were both deleted, and the number of other storage objects did not change. // Check that the access token and refresh token storage were both deleted, and the number of other storage objects did not change.
testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1)
@ -602,21 +604,23 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) (
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
require.Equal(t, test.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), "application/json") testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), "application/json")
if test.wantBodyFields != nil { require.Equal(t, test.wantStatus, rsp.Code)
if test.wantStatus == http.StatusOK {
require.NotNil(t, test.wantSuccessBodyFields, "problem with test table setup: wanted success but did not specify expected response body")
var parsedResponseBody map[string]interface{} var parsedResponseBody map[string]interface{}
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody)) require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody))
require.ElementsMatch(t, test.wantBodyFields, getMapKeys(parsedResponseBody)) require.ElementsMatch(t, test.wantSuccessBodyFields, getMapKeys(parsedResponseBody))
wantIDToken := contains(test.wantBodyFields, "id_token") wantIDToken := contains(test.wantSuccessBodyFields, "id_token")
wantRefreshToken := contains(test.wantBodyFields, "refresh_token") wantRefreshToken := contains(test.wantSuccessBodyFields, "refresh_token")
code := req.PostForm.Get("code") requireInvalidAuthCodeStorage(t, authCode, oauthStore)
requireInvalidAuthCodeStorage(t, code, oauthStore) requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes)
requireValidAccessTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, wantIDToken, wantRefreshToken) requireInvalidPKCEStorage(t, authCode, oauthStore)
requireInvalidPKCEStorage(t, code, oauthStore) requireValidOIDCStorage(t, parsedResponseBody, authCode, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes)
requireValidOIDCStorage(t, parsedResponseBody, code, oauthStore, test.wantRequestedScopes, wantIDToken, wantRefreshToken)
expectedNumberOfRefreshTokenSessionsStored := 0 expectedNumberOfRefreshTokenSessionsStored := 0
if wantRefreshToken { if wantRefreshToken {
@ -628,7 +632,7 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) (
requireValidIDToken(t, parsedResponseBody, jwtSigningKey) requireValidIDToken(t, parsedResponseBody, jwtSigningKey)
} }
if wantRefreshToken { if wantRefreshToken {
requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, wantIDToken, wantRefreshToken) requireValidRefreshTokenStorage(t, parsedResponseBody, oauthStore, test.wantRequestedScopes, test.wantGrantedScopes)
} }
testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1)
@ -638,7 +642,9 @@ func exchangeAuthcodeForTokens(t *testing.T, test authcodeExchangeInputs) (
testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: openidconnect.TypeLabelValue}, expectedNumberOfIDSessionsStored) testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{crud.SecretLabelKey: openidconnect.TypeLabelValue}, expectedNumberOfIDSessionsStored)
testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{}, 2+expectedNumberOfRefreshTokenSessionsStored+expectedNumberOfIDSessionsStored) testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secrets, labels.Set{}, 2+expectedNumberOfRefreshTokenSessionsStored+expectedNumberOfIDSessionsStored)
} else { } else {
require.JSONEq(t, test.wantExactBody, rsp.Body.String()) require.NotNil(t, test.wantErrorResponseBody, "problem with test table setup: wanted failure but did not specify failure response body")
require.JSONEq(t, test.wantErrorResponseBody, rsp.Body.String())
} }
return subject, rsp, authCode, secrets, oauthStore return subject, rsp, authCode, secrets, oauthStore
@ -736,9 +742,8 @@ func makeOauthHelperWithNilPrivateJWTSigningKey(
return oauthHelper, authResponder.GetCode(), nil return oauthHelper, authResponder.GetCode(), nil
} }
// Simulate the auth endpoint running so Fosite code will fill the store with realistic values.
func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Request, oauthHelper fosite.OAuth2Provider) fosite.AuthorizeResponder { func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Request, oauthHelper fosite.OAuth2Provider) fosite.AuthorizeResponder {
// Simulate the auth endpoint running so Fosite code will fill the store with realistic values.
//
// We only set the fields in the session that Fosite wants us to set. // We only set the fields in the session that Fosite wants us to set.
ctx := context.Background() ctx := context.Background()
session := &openid.DefaultSession{ session := &openid.DefaultSession{
@ -797,8 +802,7 @@ func requireValidRefreshTokenStorage(
body map[string]interface{}, body map[string]interface{},
storage oauth2.CoreStorage, storage oauth2.CoreStorage,
wantRequestedScopes []string, wantRequestedScopes []string,
wantGrantedOpenidScope bool, wantGrantedScopes []string,
wantGrantedOfflineAccessScope bool,
) { ) {
t.Helper() t.Helper()
@ -817,8 +821,7 @@ func requireValidRefreshTokenStorage(
storedRequest, storedRequest,
storedRequest.Sanitize([]string{}).GetRequestForm(), storedRequest.Sanitize([]string{}).GetRequestForm(),
wantRequestedScopes, wantRequestedScopes,
wantGrantedOpenidScope, wantGrantedScopes,
wantGrantedOfflineAccessScope,
true, true,
) )
} }
@ -828,8 +831,7 @@ func requireValidAccessTokenStorage(
body map[string]interface{}, body map[string]interface{},
storage oauth2.CoreStorage, storage oauth2.CoreStorage,
wantRequestedScopes []string, wantRequestedScopes []string,
wantGrantedOpenidScope bool, wantGrantedScopes []string,
wantGrantedOfflineAccessScope bool,
) { ) {
t.Helper() t.Helper()
@ -859,7 +861,7 @@ func requireValidAccessTokenStorage(
require.True(t, ok) require.True(t, ok)
actualGrantedScopesString, ok := scopes.(string) actualGrantedScopesString, ok := scopes.(string)
require.Truef(t, ok, "wanted scopes to be an string, but got %T", scopes) require.Truef(t, ok, "wanted scopes to be an string, but got %T", scopes)
require.Equal(t, strings.Join(wantGrantedScopes(wantGrantedOpenidScope, wantGrantedOfflineAccessScope), " "), actualGrantedScopesString) require.Equal(t, strings.Join(wantGrantedScopes, " "), actualGrantedScopesString)
// Fosite stores access tokens without any of the original request form parameters. // Fosite stores access tokens without any of the original request form parameters.
requireValidStoredRequest( requireValidStoredRequest(
@ -867,23 +869,11 @@ func requireValidAccessTokenStorage(
storedRequest, storedRequest,
storedRequest.Sanitize([]string{}).GetRequestForm(), storedRequest.Sanitize([]string{}).GetRequestForm(),
wantRequestedScopes, wantRequestedScopes,
wantGrantedOpenidScope, wantGrantedScopes,
wantGrantedOfflineAccessScope,
true, true,
) )
} }
func wantGrantedScopes(wantGrantedOpenidScope, wantGrantedOfflineAccessScope bool) []string {
scopesWanted := []string{}
if wantGrantedOpenidScope {
scopesWanted = append(scopesWanted, "openid")
}
if wantGrantedOfflineAccessScope {
scopesWanted = append(scopesWanted, "offline_access")
}
return scopesWanted
}
func requireInvalidAccessTokenStorage( func requireInvalidAccessTokenStorage(
t *testing.T, t *testing.T,
body map[string]interface{}, body map[string]interface{},
@ -919,12 +909,11 @@ func requireValidOIDCStorage(
code string, code string,
storage openid.OpenIDConnectRequestStorage, storage openid.OpenIDConnectRequestStorage,
wantRequestedScopes []string, wantRequestedScopes []string,
wantGrantedOpenidScope bool, wantGrantedScopes []string,
wantGrantedOfflineAccessScope bool,
) { ) {
t.Helper() t.Helper()
if wantGrantedOpenidScope { if contains(wantGrantedScopes, "openid") {
// Make sure the OIDC session is still there. Note that Fosite stores OIDC sessions using the full auth code as a key. // Make sure the OIDC session is still there. Note that Fosite stores OIDC sessions using the full auth code as a key.
storedRequest, err := storage.GetOpenIDConnectSession(context.Background(), code, nil) storedRequest, err := storage.GetOpenIDConnectSession(context.Background(), code, nil)
require.NoError(t, err) require.NoError(t, err)
@ -941,8 +930,7 @@ func requireValidOIDCStorage(
storedRequest, storedRequest,
storedRequest.Sanitize([]string{"nonce"}).GetRequestForm(), storedRequest.Sanitize([]string{"nonce"}).GetRequestForm(),
wantRequestedScopes, wantRequestedScopes,
wantGrantedOpenidScope, wantGrantedScopes,
wantGrantedOfflineAccessScope,
false, false,
) )
} else { } else {
@ -956,8 +944,7 @@ func requireValidStoredRequest(
request fosite.Requester, request fosite.Requester,
wantRequestForm url.Values, wantRequestForm url.Values,
wantRequestedScopes []string, wantRequestedScopes []string,
wantGrantedOpenidScope bool, wantGrantedScopes []string,
wantGrantedOfflineAccessScope bool,
wantAccessTokenExpiresAt bool, wantAccessTokenExpiresAt bool,
) { ) {
t.Helper() t.Helper()
@ -967,7 +954,7 @@ func requireValidStoredRequest(
testutil.RequireTimeInDelta(t, request.GetRequestedAt(), time.Now().UTC(), timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, request.GetRequestedAt(), time.Now().UTC(), timeComparisonFudgeSeconds*time.Second)
require.Equal(t, goodClient, request.GetClient().GetID()) require.Equal(t, goodClient, request.GetClient().GetID())
require.Equal(t, fosite.Arguments(wantRequestedScopes), request.GetRequestedScopes()) require.Equal(t, fosite.Arguments(wantRequestedScopes), request.GetRequestedScopes())
require.Equal(t, fosite.Arguments(wantGrantedScopes(wantGrantedOpenidScope, wantGrantedOfflineAccessScope)), request.GetGrantedScopes()) require.Equal(t, fosite.Arguments(wantGrantedScopes), request.GetGrantedScopes())
require.Empty(t, request.GetRequestedAudience()) require.Empty(t, request.GetRequestedAudience())
require.Empty(t, request.GetGrantedAudience()) require.Empty(t, request.GetGrantedAudience())
require.Equal(t, wantRequestForm, request.GetRequestForm()) // Fosite stores access token request without form require.Equal(t, wantRequestForm, request.GetRequestForm()) // Fosite stores access token request without form
@ -977,7 +964,7 @@ func requireValidStoredRequest(
require.Truef(t, ok, "could not cast %T to %T", request.GetSession(), &openid.DefaultSession{}) require.Truef(t, ok, "could not cast %T to %T", request.GetSession(), &openid.DefaultSession{})
// Assert that the session claims are what we think they should be, but only if we are doing OIDC. // Assert that the session claims are what we think they should be, but only if we are doing OIDC.
if wantGrantedOpenidScope { if contains(wantGrantedScopes, "openid") {
claims := session.Claims claims := session.Claims
require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field. require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field.
require.Equal(t, goodSubject, claims.Subject) require.Equal(t, goodSubject, claims.Subject)