token_handler_test.go: Add tests for username and groups custom claims
This commit is contained in:
parent
afcd5e3e36
commit
16dfab0aff
@ -23,23 +23,6 @@ import (
|
|||||||
"go.pinniped.dev/internal/plog"
|
"go.pinniped.dev/internal/plog"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
// The name of the issuer claim specified in the OIDC spec.
|
|
||||||
idTokenIssuerClaim = "iss"
|
|
||||||
|
|
||||||
// The name of the subject claim specified in the OIDC spec.
|
|
||||||
idTokenSubjectClaim = "sub"
|
|
||||||
|
|
||||||
// idTokenUsernameClaim is a custom claim in the downstream ID token
|
|
||||||
// whose value is mapped from a claim in the upstream token.
|
|
||||||
// By default the value is the same as the downstream subject claim's.
|
|
||||||
idTokenUsernameClaim = "username"
|
|
||||||
|
|
||||||
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
|
||||||
// information.
|
|
||||||
downstreamGroupsClaim = "groups"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewHandler(
|
func NewHandler(
|
||||||
idpListGetter oidc.IDPListGetter,
|
idpListGetter oidc.IDPListGetter,
|
||||||
oauthHelper fosite.OAuth2Provider,
|
oauthHelper fosite.OAuth2Provider,
|
||||||
@ -199,7 +182,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
|
|||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
// The spec says the "sub" claim is only unique per issuer,
|
// The spec says the "sub" claim is only unique per issuer,
|
||||||
// so we will prepend the issuer string to make it globally unique.
|
// so we will prepend the issuer string to make it globally unique.
|
||||||
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
|
upstreamIssuer := idTokenClaims[oidc.IDTokenIssuerClaim]
|
||||||
if upstreamIssuer == "" {
|
if upstreamIssuer == "" {
|
||||||
plog.Warning(
|
plog.Warning(
|
||||||
"issuer claim in upstream ID token missing",
|
"issuer claim in upstream ID token missing",
|
||||||
@ -218,7 +201,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
|
|||||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
|
||||||
}
|
}
|
||||||
|
|
||||||
subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim]
|
subjectAsInterface, ok := idTokenClaims[oidc.IDTokenSubjectClaim]
|
||||||
if !ok {
|
if !ok {
|
||||||
plog.Warning(
|
plog.Warning(
|
||||||
"no subject claim in upstream ID token",
|
"no subject claim in upstream ID token",
|
||||||
@ -236,7 +219,7 @@ func getSubjectAndUsernameFromUpstreamIDToken(
|
|||||||
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
|
return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format")
|
||||||
}
|
}
|
||||||
|
|
||||||
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject)
|
subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, upstreamSubject)
|
||||||
|
|
||||||
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
|
||||||
if usernameClaim == "" {
|
if usernameClaim == "" {
|
||||||
@ -312,10 +295,10 @@ func makeDownstreamSession(subject string, username string, groups []string) *op
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
openIDSession.Claims.Extra = map[string]interface{}{
|
openIDSession.Claims.Extra = map[string]interface{}{
|
||||||
idTokenUsernameClaim: username,
|
oidc.DownstreamUsernameClaim: username,
|
||||||
}
|
}
|
||||||
if groups != nil {
|
if groups != nil {
|
||||||
openIDSession.Claims.Extra[downstreamGroupsClaim] = groups
|
openIDSession.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
|
||||||
}
|
}
|
||||||
return openIDSession
|
return openIDSession
|
||||||
}
|
}
|
||||||
|
@ -44,6 +44,21 @@ const (
|
|||||||
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
|
// CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF
|
||||||
// cookie contents.
|
// cookie contents.
|
||||||
CSRFCookieEncodingName = "csrf"
|
CSRFCookieEncodingName = "csrf"
|
||||||
|
|
||||||
|
// The name of the issuer claim specified in the OIDC spec.
|
||||||
|
IDTokenIssuerClaim = "iss"
|
||||||
|
|
||||||
|
// The name of the subject claim specified in the OIDC spec.
|
||||||
|
IDTokenSubjectClaim = "sub"
|
||||||
|
|
||||||
|
// DownstreamUsernameClaim is a custom claim in the downstream ID token
|
||||||
|
// whose value is mapped from a claim in the upstream token.
|
||||||
|
// By default the value is the same as the downstream subject claim's.
|
||||||
|
DownstreamUsernameClaim = "username"
|
||||||
|
|
||||||
|
// DownstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
|
||||||
|
// information.
|
||||||
|
DownstreamGroupsClaim = "groups"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Encoder is the encoding side of the securecookie.Codec interface.
|
// Encoder is the encoding side of the securecookie.Codec interface.
|
||||||
|
@ -53,8 +53,9 @@ const (
|
|||||||
goodRedirectURI = "http://127.0.0.1/callback"
|
goodRedirectURI = "http://127.0.0.1/callback"
|
||||||
goodPKCECodeVerifier = "some-pkce-verifier-that-must-be-at-least-43-characters-to-meet-entropy-requirements"
|
goodPKCECodeVerifier = "some-pkce-verifier-that-must-be-at-least-43-characters-to-meet-entropy-requirements"
|
||||||
goodNonce = "some-nonce-value-with-enough-bytes-to-exceed-min-allowed"
|
goodNonce = "some-nonce-value-with-enough-bytes-to-exceed-min-allowed"
|
||||||
goodSubject = "some-subject"
|
goodSubject = "https://issuer?sub=some-subject"
|
||||||
goodUsername = "some-username"
|
goodUsername = "some-username"
|
||||||
|
goodGroups = "group1,groups2"
|
||||||
|
|
||||||
hmacSecret = "this needs to be at least 32 characters to meet entropy requirements"
|
hmacSecret = "this needs to be at least 32 characters to meet entropy requirements"
|
||||||
|
|
||||||
@ -781,11 +782,13 @@ func TestTokenExchange(t *testing.T) {
|
|||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
test := test
|
test := test
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange)
|
t.Parallel()
|
||||||
var parsedResponseBody map[string]interface{}
|
|
||||||
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody))
|
|
||||||
|
|
||||||
request := happyTokenExchangeRequest(test.requestedAudience, parsedResponseBody["access_token"].(string))
|
subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange)
|
||||||
|
var parsedAuthcodeExchangeResponseBody map[string]interface{}
|
||||||
|
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody))
|
||||||
|
|
||||||
|
request := happyTokenExchangeRequest(test.requestedAudience, parsedAuthcodeExchangeResponseBody["access_token"].(string))
|
||||||
if test.modifyStorage != nil {
|
if test.modifyStorage != nil {
|
||||||
test.modifyStorage(t, storage, request)
|
test.modifyStorage(t, storage, request)
|
||||||
}
|
}
|
||||||
@ -801,6 +804,10 @@ func TestTokenExchange(t *testing.T) {
|
|||||||
existingSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
|
existingSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait one second before performing the token exchange so we can see that the new ID token has new issued
|
||||||
|
// at and expires at dates which are newer than the old tokens.
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
subject.ServeHTTP(rsp, req)
|
subject.ServeHTTP(rsp, req)
|
||||||
t.Logf("response: %#v", rsp)
|
t.Logf("response: %#v", rsp)
|
||||||
t.Logf("response body: %q", rsp.Body.String())
|
t.Logf("response body: %q", rsp.Body.String())
|
||||||
@ -816,6 +823,12 @@ func TestTokenExchange(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
claimsOfFirstIDToken := map[string]interface{}{}
|
||||||
|
originalIDToken := parsedAuthcodeExchangeResponseBody["id_token"].(string)
|
||||||
|
firstIDTokenDecoded, _ := josejwt.ParseSigned(originalIDToken)
|
||||||
|
err = firstIDTokenDecoded.UnsafeClaimsWithoutVerification(&claimsOfFirstIDToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
var responseBody map[string]interface{}
|
var responseBody map[string]interface{}
|
||||||
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &responseBody))
|
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &responseBody))
|
||||||
|
|
||||||
@ -823,18 +836,43 @@ func TestTokenExchange(t *testing.T) {
|
|||||||
require.Equal(t, "N_A", responseBody["token_type"])
|
require.Equal(t, "N_A", responseBody["token_type"])
|
||||||
require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", responseBody["issued_token_type"])
|
require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", responseBody["issued_token_type"])
|
||||||
|
|
||||||
// Assert that the returned token has expected claims.
|
// Parse the returned token.
|
||||||
parsedJWT, err := jose.ParseSigned(responseBody["access_token"].(string))
|
parsedJWT, err := jose.ParseSigned(responseBody["access_token"].(string))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
var tokenClaims map[string]interface{}
|
var tokenClaims map[string]interface{}
|
||||||
require.NoError(t, json.Unmarshal(parsedJWT.UnsafePayloadWithoutVerification(), &tokenClaims))
|
require.NoError(t, json.Unmarshal(parsedJWT.UnsafePayloadWithoutVerification(), &tokenClaims))
|
||||||
require.Contains(t, tokenClaims, "iat")
|
|
||||||
require.Contains(t, tokenClaims, "rat")
|
// Make sure that these are the only fields in the token.
|
||||||
require.Contains(t, tokenClaims, "jti")
|
idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"}
|
||||||
|
require.ElementsMatch(t, idTokenFields, getMapKeys(tokenClaims))
|
||||||
|
|
||||||
|
// Assert that the returned token has expected claims values.
|
||||||
|
require.NotEmpty(t, tokenClaims["jti"])
|
||||||
|
require.NotEmpty(t, tokenClaims["auth_time"])
|
||||||
|
require.NotEmpty(t, tokenClaims["exp"])
|
||||||
|
require.NotEmpty(t, tokenClaims["iat"])
|
||||||
|
require.NotEmpty(t, tokenClaims["rat"])
|
||||||
|
require.Empty(t, tokenClaims["nonce"]) // ID tokens only contain nonce during an authcode exchange
|
||||||
require.Len(t, tokenClaims["aud"], 1)
|
require.Len(t, tokenClaims["aud"], 1)
|
||||||
require.Contains(t, tokenClaims["aud"], test.requestedAudience)
|
require.Contains(t, tokenClaims["aud"], test.requestedAudience)
|
||||||
require.Equal(t, goodSubject, tokenClaims["sub"])
|
require.Equal(t, goodSubject, tokenClaims["sub"])
|
||||||
require.Equal(t, goodIssuer, tokenClaims["iss"])
|
require.Equal(t, goodIssuer, tokenClaims["iss"])
|
||||||
|
require.Equal(t, goodUsername, tokenClaims["username"])
|
||||||
|
require.Equal(t, goodGroups, tokenClaims["groups"])
|
||||||
|
|
||||||
|
// Also assert that some are the same as the original downstream ID token.
|
||||||
|
requireClaimsAreEqual(t, "iss", claimsOfFirstIDToken, tokenClaims) // issuer
|
||||||
|
requireClaimsAreEqual(t, "sub", claimsOfFirstIDToken, tokenClaims) // subject
|
||||||
|
requireClaimsAreEqual(t, "rat", claimsOfFirstIDToken, tokenClaims) // requested at
|
||||||
|
requireClaimsAreEqual(t, "auth_time", claimsOfFirstIDToken, tokenClaims) // auth time
|
||||||
|
|
||||||
|
// Also assert which are the different from the original downstream ID token.
|
||||||
|
requireClaimsAreNotEqual(t, "jti", claimsOfFirstIDToken, tokenClaims) // JWT ID
|
||||||
|
requireClaimsAreNotEqual(t, "aud", claimsOfFirstIDToken, tokenClaims) // audience
|
||||||
|
requireClaimsAreNotEqual(t, "exp", claimsOfFirstIDToken, tokenClaims) // expires at
|
||||||
|
require.Greater(t, tokenClaims["exp"], claimsOfFirstIDToken["exp"])
|
||||||
|
requireClaimsAreNotEqual(t, "iat", claimsOfFirstIDToken, tokenClaims) // issued at
|
||||||
|
require.Greater(t, tokenClaims["iat"], claimsOfFirstIDToken["iat"])
|
||||||
|
|
||||||
// Assert that nothing in storage has been modified.
|
// Assert that nothing in storage has been modified.
|
||||||
newSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
|
newSecrets, err := secrets.List(context.Background(), metav1.ListOptions{})
|
||||||
@ -1386,11 +1424,15 @@ func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Reques
|
|||||||
session := &openid.DefaultSession{
|
session := &openid.DefaultSession{
|
||||||
Claims: &jwt.IDTokenClaims{
|
Claims: &jwt.IDTokenClaims{
|
||||||
Subject: goodSubject,
|
Subject: goodSubject,
|
||||||
AuthTime: goodAuthTime,
|
|
||||||
RequestedAt: goodRequestedAtTime,
|
RequestedAt: goodRequestedAtTime,
|
||||||
|
AuthTime: goodAuthTime,
|
||||||
|
Extra: map[string]interface{}{
|
||||||
|
oidc.DownstreamUsernameClaim: goodUsername,
|
||||||
|
oidc.DownstreamGroupsClaim: goodGroups,
|
||||||
},
|
},
|
||||||
Subject: goodSubject,
|
},
|
||||||
Username: goodUsername,
|
Subject: "", // not used, note that callback_handler.go does not set this
|
||||||
|
Username: "", // not used, note that callback_handler.go does not set this
|
||||||
}
|
}
|
||||||
authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest)
|
authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -1609,6 +1651,12 @@ func requireValidStoredRequest(
|
|||||||
require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field.
|
require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field.
|
||||||
require.Equal(t, goodSubject, claims.Subject)
|
require.Equal(t, goodSubject, claims.Subject)
|
||||||
|
|
||||||
|
// Our custom claims from the authorize endpoint should still be set.
|
||||||
|
require.Equal(t, map[string]interface{}{
|
||||||
|
"username": goodUsername,
|
||||||
|
"groups": goodGroups,
|
||||||
|
}, claims.Extra)
|
||||||
|
|
||||||
// We are in charge of setting these fields. For the purpose of testing, we ensure that the
|
// We are in charge of setting these fields. For the purpose of testing, we ensure that the
|
||||||
// sentinel test value is set correctly.
|
// sentinel test value is set correctly.
|
||||||
require.Equal(t, goodRequestedAtTime, claims.RequestedAt)
|
require.Equal(t, goodRequestedAtTime, claims.RequestedAt)
|
||||||
@ -1633,7 +1681,6 @@ func requireValidStoredRequest(
|
|||||||
require.Empty(t, claims.AuthenticationContextClassReference)
|
require.Empty(t, claims.AuthenticationContextClassReference)
|
||||||
require.Empty(t, claims.AuthenticationMethodsReference)
|
require.Empty(t, claims.AuthenticationMethodsReference)
|
||||||
require.Empty(t, claims.CodeHash)
|
require.Empty(t, claims.CodeHash)
|
||||||
require.Empty(t, claims.Extra)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert that the session headers are what we think they should be.
|
// Assert that the session headers are what we think they should be.
|
||||||
@ -1664,9 +1711,9 @@ func requireValidStoredRequest(
|
|||||||
require.False(t, ok, "expected session to not hold expiration time for access token, but it did")
|
require.False(t, ok, "expected session to not hold expiration time for access token, but it did")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assert that the session's username and subject are correct.
|
// We don't use these, so they should be empty.
|
||||||
require.Equal(t, goodUsername, session.Username)
|
require.Empty(t, session.Username)
|
||||||
require.Equal(t, goodSubject, session.Subject)
|
require.Empty(t, session.Subject)
|
||||||
}
|
}
|
||||||
|
|
||||||
func requireValidIDToken(
|
func requireValidIDToken(
|
||||||
@ -1698,12 +1745,14 @@ func requireValidIDToken(
|
|||||||
IssuedAt int64 `json:"iat"`
|
IssuedAt int64 `json:"iat"`
|
||||||
RequestedAt int64 `json:"rat"`
|
RequestedAt int64 `json:"rat"`
|
||||||
AuthTime int64 `json:"auth_time"`
|
AuthTime int64 `json:"auth_time"`
|
||||||
|
Groups string `json:"groups"`
|
||||||
|
Username string `json:"username"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that there is a bug in fosite which prevents the `at_hash` claim from appearing in this ID token
|
// Note that there is a bug in fosite which prevents the `at_hash` claim from appearing in this ID token
|
||||||
// during the initial authcode exchange, but does not prevent `at_hash` from appearing in the refreshed ID token.
|
// during the initial authcode exchange, but does not prevent `at_hash` from appearing in the refreshed ID token.
|
||||||
// We can add a workaround for this later.
|
// We can add a workaround for this later.
|
||||||
idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat"}
|
idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"}
|
||||||
if wantAtHashClaimInIDToken {
|
if wantAtHashClaimInIDToken {
|
||||||
idTokenFields = append(idTokenFields, "at_hash")
|
idTokenFields = append(idTokenFields, "at_hash")
|
||||||
}
|
}
|
||||||
@ -1717,6 +1766,8 @@ func requireValidIDToken(
|
|||||||
err := token.Claims(&claims)
|
err := token.Claims(&claims)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, goodSubject, claims.Subject)
|
require.Equal(t, goodSubject, claims.Subject)
|
||||||
|
require.Equal(t, goodUsername, claims.Username)
|
||||||
|
require.Equal(t, goodGroups, claims.Groups)
|
||||||
require.Len(t, claims.Audience, 1)
|
require.Len(t, claims.Audience, 1)
|
||||||
require.Equal(t, goodClient, claims.Audience[0])
|
require.Equal(t, goodClient, claims.Audience[0])
|
||||||
require.Equal(t, goodIssuer, claims.Issuer)
|
require.Equal(t, goodIssuer, claims.Issuer)
|
||||||
|
Loading…
Reference in New Issue
Block a user