token_handler_test.go: Add tests for username and groups custom claims

This commit is contained in:
Ryan Richard 2020-12-14 18:27:14 -08:00
parent afcd5e3e36
commit 16dfab0aff
3 changed files with 88 additions and 39 deletions

View File

@ -23,23 +23,6 @@ import (
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
const (
// The name of the issuer claim specified in the OIDC spec.
idTokenIssuerClaim = "iss"
// The name of the subject claim specified in the OIDC spec.
idTokenSubjectClaim = "sub"
// idTokenUsernameClaim is a custom claim in the downstream ID token
// whose value is mapped from a claim in the upstream token.
// By default the value is the same as the downstream subject claim's.
idTokenUsernameClaim = "username"
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
downstreamGroupsClaim = "groups"
)
func NewHandler( func NewHandler(
idpListGetter oidc.IDPListGetter, idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
@ -199,7 +182,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
) (string, string, error) { ) (string, string, error) {
// The spec says the "sub" claim is only unique per issuer, // The spec says the "sub" claim is only unique per issuer,
// so we will prepend the issuer string to make it globally unique. // so we will prepend the issuer string to make it globally unique.
upstreamIssuer := idTokenClaims[idTokenIssuerClaim] upstreamIssuer := idTokenClaims[oidc.IDTokenIssuerClaim]
if upstreamIssuer == "" { if upstreamIssuer == "" {
plog.Warning( plog.Warning(
"issuer claim in upstream ID token missing", "issuer claim in upstream ID token missing",
@ -218,7 +201,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
} }
subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim] subjectAsInterface, ok := idTokenClaims[oidc.IDTokenSubjectClaim]
if !ok { if !ok {
plog.Warning( plog.Warning(
"no subject claim in upstream ID token", "no subject claim in upstream ID token",
@ -236,7 +219,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format") return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
} }
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject) subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, upstreamSubject)
usernameClaim := upstreamIDPConfig.GetUsernameClaim() usernameClaim := upstreamIDPConfig.GetUsernameClaim()
if usernameClaim == "" { if usernameClaim == "" {
@ -312,10 +295,10 @@ func makeDownstreamSession(subject string, username string, groups []string) *op
}, },
} }
openIDSession.Claims.Extra = map[string]interface{}{ openIDSession.Claims.Extra = map[string]interface{}{
idTokenUsernameClaim: username, oidc.DownstreamUsernameClaim: username,
} }
if groups != nil { if groups != nil {
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups openIDSession.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
} }
return openIDSession return openIDSession
} }

View File

@ -44,6 +44,21 @@ const (
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF // CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
// cookie contents. // cookie contents.
CSRFCookieEncodingName = "csrf" CSRFCookieEncodingName = "csrf"
// The name of the issuer claim specified in the OIDC spec.
IDTokenIssuerClaim = "iss"
// The name of the subject claim specified in the OIDC spec.
IDTokenSubjectClaim = "sub"
// DownstreamUsernameClaim is a custom claim in the downstream ID token
// whose value is mapped from a claim in the upstream token.
// By default the value is the same as the downstream subject claim's.
DownstreamUsernameClaim = "username"
// DownstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
DownstreamGroupsClaim = "groups"
) )
// Encoder is the encoding side of the securecookie.Codec interface. // Encoder is the encoding side of the securecookie.Codec interface.

View File

@ -53,8 +53,9 @@ const (
goodRedirectURI = "http://127.0.0.1/callback" goodRedirectURI = "http://127.0.0.1/callback"
goodPKCECodeVerifier = "some-pkce-verifier-that-must-be-at-least-43-characters-to-meet-entropy-requirements" goodPKCECodeVerifier = "some-pkce-verifier-that-must-be-at-least-43-characters-to-meet-entropy-requirements"
goodNonce = "some-nonce-value-with-enough-bytes-to-exceed-min-allowed" goodNonce = "some-nonce-value-with-enough-bytes-to-exceed-min-allowed"
goodSubject = "some-subject" goodSubject = "https://issuer?sub=some-subject"
goodUsername = "some-username" goodUsername = "some-username"
goodGroups = "group1,groups2"
hmacSecret = "this needs to be at least 32 characters to meet entropy requirements" hmacSecret = "this needs to be at least 32 characters to meet entropy requirements"
@ -781,11 +782,13 @@ func TestTokenExchange(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange) t.Parallel()
var parsedResponseBody map[string]interface{}
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody))
request := happyTokenExchangeRequest(test.requestedAudience, parsedResponseBody["access_token"].(string)) subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange)
var parsedAuthcodeExchangeResponseBody map[string]interface{}
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody))
request := happyTokenExchangeRequest(test.requestedAudience, parsedAuthcodeExchangeResponseBody["access_token"].(string))
if test.modifyStorage != nil { if test.modifyStorage != nil {
test.modifyStorage(t, storage, request) test.modifyStorage(t, storage, request)
} }
@ -801,6 +804,10 @@ func TestTokenExchange(t *testing.T) {
existingSecrets, err := secrets.List(context.Background(), metav1.ListOptions{}) existingSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
require.NoError(t, err) require.NoError(t, err)
// Wait one second before performing the token exchange so we can see that the new ID token has new issued
// at and expires at dates which are newer than the old tokens.
time.Sleep(1 * time.Second)
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
t.Logf("response: %#v", rsp) t.Logf("response: %#v", rsp)
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
@ -816,6 +823,12 @@ func TestTokenExchange(t *testing.T) {
return return
} }
claimsOfFirstIDToken := map[string]interface{}{}
originalIDToken := parsedAuthcodeExchangeResponseBody["id_token"].(string)
firstIDTokenDecoded, _ := josejwt.ParseSigned(originalIDToken)
err = firstIDTokenDecoded.UnsafeClaimsWithoutVerification(&claimsOfFirstIDToken)
require.NoError(t, err)
var responseBody map[string]interface{} var responseBody map[string]interface{}
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &responseBody)) require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &responseBody))
@ -823,18 +836,43 @@ func TestTokenExchange(t *testing.T) {
require.Equal(t, "N_A", responseBody["token_type"]) require.Equal(t, "N_A", responseBody["token_type"])
require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", responseBody["issued_token_type"]) require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", responseBody["issued_token_type"])
// Assert that the returned token has expected claims. // Parse the returned token.
parsedJWT, err := jose.ParseSigned(responseBody["access_token"].(string)) parsedJWT, err := jose.ParseSigned(responseBody["access_token"].(string))
require.NoError(t, err) require.NoError(t, err)
var tokenClaims map[string]interface{} var tokenClaims map[string]interface{}
require.NoError(t, json.Unmarshal(parsedJWT.UnsafePayloadWithoutVerification(), &tokenClaims)) require.NoError(t, json.Unmarshal(parsedJWT.UnsafePayloadWithoutVerification(), &tokenClaims))
require.Contains(t, tokenClaims, "iat")
require.Contains(t, tokenClaims, "rat") // Make sure that these are the only fields in the token.
require.Contains(t, tokenClaims, "jti") idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"}
require.ElementsMatch(t, idTokenFields, getMapKeys(tokenClaims))
// Assert that the returned token has expected claims values.
require.NotEmpty(t, tokenClaims["jti"])
require.NotEmpty(t, tokenClaims["auth_time"])
require.NotEmpty(t, tokenClaims["exp"])
require.NotEmpty(t, tokenClaims["iat"])
require.NotEmpty(t, tokenClaims["rat"])
require.Empty(t, tokenClaims["nonce"]) // ID tokens only contain nonce during an authcode exchange
require.Len(t, tokenClaims["aud"], 1) require.Len(t, tokenClaims["aud"], 1)
require.Contains(t, tokenClaims["aud"], test.requestedAudience) require.Contains(t, tokenClaims["aud"], test.requestedAudience)
require.Equal(t, goodSubject, tokenClaims["sub"]) require.Equal(t, goodSubject, tokenClaims["sub"])
require.Equal(t, goodIssuer, tokenClaims["iss"]) require.Equal(t, goodIssuer, tokenClaims["iss"])
require.Equal(t, goodUsername, tokenClaims["username"])
require.Equal(t, goodGroups, tokenClaims["groups"])
// Also assert that some are the same as the original downstream ID token.
requireClaimsAreEqual(t, "iss", claimsOfFirstIDToken, tokenClaims) // issuer
requireClaimsAreEqual(t, "sub", claimsOfFirstIDToken, tokenClaims) // subject
requireClaimsAreEqual(t, "rat", claimsOfFirstIDToken, tokenClaims) // requested at
requireClaimsAreEqual(t, "auth_time", claimsOfFirstIDToken, tokenClaims) // auth time
// Also assert which are the different from the original downstream ID token.
requireClaimsAreNotEqual(t, "jti", claimsOfFirstIDToken, tokenClaims) // JWT ID
requireClaimsAreNotEqual(t, "aud", claimsOfFirstIDToken, tokenClaims) // audience
requireClaimsAreNotEqual(t, "exp", claimsOfFirstIDToken, tokenClaims) // expires at
require.Greater(t, tokenClaims["exp"], claimsOfFirstIDToken["exp"])
requireClaimsAreNotEqual(t, "iat", claimsOfFirstIDToken, tokenClaims) // issued at
require.Greater(t, tokenClaims["iat"], claimsOfFirstIDToken["iat"])
// Assert that nothing in storage has been modified. // Assert that nothing in storage has been modified.
newSecrets, err := secrets.List(context.Background(), metav1.ListOptions{}) newSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
@ -1386,11 +1424,15 @@ func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Reques
session := &openid.DefaultSession{ session := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Subject: goodSubject, Subject: goodSubject,
AuthTime: goodAuthTime,
RequestedAt: goodRequestedAtTime, RequestedAt: goodRequestedAtTime,
AuthTime: goodAuthTime,
Extra: map[string]interface{}{
oidc.DownstreamUsernameClaim: goodUsername,
oidc.DownstreamGroupsClaim: goodGroups,
}, },
Subject: goodSubject, },
Username: goodUsername, Subject: "", // not used, note that callback_handler.go does not set this
Username: "", // not used, note that callback_handler.go does not set this
} }
authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest) authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest)
require.NoError(t, err) require.NoError(t, err)
@ -1609,6 +1651,12 @@ func requireValidStoredRequest(
require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field. require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field.
require.Equal(t, goodSubject, claims.Subject) require.Equal(t, goodSubject, claims.Subject)
// Our custom claims from the authorize endpoint should still be set.
require.Equal(t, map[string]interface{}{
"username": goodUsername,
"groups": goodGroups,
}, claims.Extra)
// We are in charge of setting these fields. For the purpose of testing, we ensure that the // We are in charge of setting these fields. For the purpose of testing, we ensure that the
// sentinel test value is set correctly. // sentinel test value is set correctly.
require.Equal(t, goodRequestedAtTime, claims.RequestedAt) require.Equal(t, goodRequestedAtTime, claims.RequestedAt)
@ -1633,7 +1681,6 @@ func requireValidStoredRequest(
require.Empty(t, claims.AuthenticationContextClassReference) require.Empty(t, claims.AuthenticationContextClassReference)
require.Empty(t, claims.AuthenticationMethodsReference) require.Empty(t, claims.AuthenticationMethodsReference)
require.Empty(t, claims.CodeHash) require.Empty(t, claims.CodeHash)
require.Empty(t, claims.Extra)
} }
// Assert that the session headers are what we think they should be. // Assert that the session headers are what we think they should be.
@ -1664,9 +1711,9 @@ func requireValidStoredRequest(
require.False(t, ok, "expected session to not hold expiration time for access token, but it did") require.False(t, ok, "expected session to not hold expiration time for access token, but it did")
} }
// Assert that the session's username and subject are correct. // We don't use these, so they should be empty.
require.Equal(t, goodUsername, session.Username) require.Empty(t, session.Username)
require.Equal(t, goodSubject, session.Subject) require.Empty(t, session.Subject)
} }
func requireValidIDToken( func requireValidIDToken(
@ -1698,12 +1745,14 @@ func requireValidIDToken(
IssuedAt int64 `json:"iat"` IssuedAt int64 `json:"iat"`
RequestedAt int64 `json:"rat"` RequestedAt int64 `json:"rat"`
AuthTime int64 `json:"auth_time"` AuthTime int64 `json:"auth_time"`
Groups string `json:"groups"`
Username string `json:"username"`
} }
// Note that there is a bug in fosite which prevents the `at_hash` claim from appearing in this ID token // Note that there is a bug in fosite which prevents the `at_hash` claim from appearing in this ID token
// during the initial authcode exchange, but does not prevent `at_hash` from appearing in the refreshed ID token. // during the initial authcode exchange, but does not prevent `at_hash` from appearing in the refreshed ID token.
// We can add a workaround for this later. // We can add a workaround for this later.
idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat"} idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"}
if wantAtHashClaimInIDToken { if wantAtHashClaimInIDToken {
idTokenFields = append(idTokenFields, "at_hash") idTokenFields = append(idTokenFields, "at_hash")
} }
@ -1717,6 +1766,8 @@ func requireValidIDToken(
err := token.Claims(&claims) err := token.Claims(&claims)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, goodSubject, claims.Subject) require.Equal(t, goodSubject, claims.Subject)
require.Equal(t, goodUsername, claims.Username)
require.Equal(t, goodGroups, claims.Groups)
require.Len(t, claims.Audience, 1) require.Len(t, claims.Audience, 1)
require.Equal(t, goodClient, claims.Audience[0]) require.Equal(t, goodClient, claims.Audience[0])
require.Equal(t, goodIssuer, claims.Issuer) require.Equal(t, goodIssuer, claims.Issuer)