From 16dfab0aff3f21cc62ab6bb1006cbf1a79d71ed3 Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Mon, 14 Dec 2020 18:27:14 -0800 Subject: [PATCH] token_handler_test.go: Add tests for username and groups custom claims --- internal/oidc/callback/callback_handler.go | 27 ++----- internal/oidc/oidc.go | 15 ++++ internal/oidc/token/token_handler_test.go | 85 +++++++++++++++++----- 3 files changed, 88 insertions(+), 39 deletions(-) diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 3bdcd9ee..97731aec 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -23,23 +23,6 @@ import ( "go.pinniped.dev/internal/plog" ) -const ( - // The name of the issuer claim specified in the OIDC spec. - idTokenIssuerClaim = "iss" - - // The name of the subject claim specified in the OIDC spec. - idTokenSubjectClaim = "sub" - - // idTokenUsernameClaim is a custom claim in the downstream ID token - // whose value is mapped from a claim in the upstream token. - // By default the value is the same as the downstream subject claim's. - idTokenUsernameClaim = "username" - - // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token - // information. - downstreamGroupsClaim = "groups" -) - func NewHandler( idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, @@ -199,7 +182,7 @@ func getSubjectAndUsernameFromUpstreamIDToken( ) (string, string, error) { // The spec says the "sub" claim is only unique per issuer, // so we will prepend the issuer string to make it globally unique. - upstreamIssuer := idTokenClaims[idTokenIssuerClaim] + upstreamIssuer := idTokenClaims[oidc.IDTokenIssuerClaim] if upstreamIssuer == "" { plog.Warning( "issuer claim in upstream ID token missing", @@ -218,7 +201,7 @@ func getSubjectAndUsernameFromUpstreamIDToken( return "", "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format") } - subjectAsInterface, ok := idTokenClaims[idTokenSubjectClaim] + subjectAsInterface, ok := idTokenClaims[oidc.IDTokenSubjectClaim] if !ok { plog.Warning( "no subject claim in upstream ID token", @@ -236,7 +219,7 @@ func getSubjectAndUsernameFromUpstreamIDToken( return "", "", httperr.New(http.StatusUnprocessableEntity, "subject claim in upstream ID token has invalid format") } - subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, idTokenSubjectClaim, upstreamSubject) + subject := fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, upstreamSubject) usernameClaim := upstreamIDPConfig.GetUsernameClaim() if usernameClaim == "" { @@ -312,10 +295,10 @@ func makeDownstreamSession(subject string, username string, groups []string) *op }, } openIDSession.Claims.Extra = map[string]interface{}{ - idTokenUsernameClaim: username, + oidc.DownstreamUsernameClaim: username, } if groups != nil { - openIDSession.Claims.Extra[downstreamGroupsClaim] = groups + openIDSession.Claims.Extra[oidc.DownstreamGroupsClaim] = groups } return openIDSession } diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index a1cb56fe..d53ee663 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -44,6 +44,21 @@ const ( // CSRFCookieEncodingName is the `name` passed to the encoder for encoding and decoding the CSRF // cookie contents. CSRFCookieEncodingName = "csrf" + + // The name of the issuer claim specified in the OIDC spec. + IDTokenIssuerClaim = "iss" + + // The name of the subject claim specified in the OIDC spec. + IDTokenSubjectClaim = "sub" + + // DownstreamUsernameClaim is a custom claim in the downstream ID token + // whose value is mapped from a claim in the upstream token. + // By default the value is the same as the downstream subject claim's. + DownstreamUsernameClaim = "username" + + // DownstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token + // information. + DownstreamGroupsClaim = "groups" ) // Encoder is the encoding side of the securecookie.Codec interface. diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index f4610293..0f31e1dd 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -53,8 +53,9 @@ const ( goodRedirectURI = "http://127.0.0.1/callback" goodPKCECodeVerifier = "some-pkce-verifier-that-must-be-at-least-43-characters-to-meet-entropy-requirements" goodNonce = "some-nonce-value-with-enough-bytes-to-exceed-min-allowed" - goodSubject = "some-subject" + goodSubject = "https://issuer?sub=some-subject" goodUsername = "some-username" + goodGroups = "group1,groups2" hmacSecret = "this needs to be at least 32 characters to meet entropy requirements" @@ -781,11 +782,13 @@ func TestTokenExchange(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange) - var parsedResponseBody map[string]interface{} - require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedResponseBody)) + t.Parallel() - request := happyTokenExchangeRequest(test.requestedAudience, parsedResponseBody["access_token"].(string)) + subject, rsp, _, _, secrets, storage := exchangeAuthcodeForTokens(t, test.authcodeExchange) + var parsedAuthcodeExchangeResponseBody map[string]interface{} + require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &parsedAuthcodeExchangeResponseBody)) + + request := happyTokenExchangeRequest(test.requestedAudience, parsedAuthcodeExchangeResponseBody["access_token"].(string)) if test.modifyStorage != nil { test.modifyStorage(t, storage, request) } @@ -801,6 +804,10 @@ func TestTokenExchange(t *testing.T) { existingSecrets, err := secrets.List(context.Background(), metav1.ListOptions{}) require.NoError(t, err) + // Wait one second before performing the token exchange so we can see that the new ID token has new issued + // at and expires at dates which are newer than the old tokens. + time.Sleep(1 * time.Second) + subject.ServeHTTP(rsp, req) t.Logf("response: %#v", rsp) t.Logf("response body: %q", rsp.Body.String()) @@ -816,6 +823,12 @@ func TestTokenExchange(t *testing.T) { return } + claimsOfFirstIDToken := map[string]interface{}{} + originalIDToken := parsedAuthcodeExchangeResponseBody["id_token"].(string) + firstIDTokenDecoded, _ := josejwt.ParseSigned(originalIDToken) + err = firstIDTokenDecoded.UnsafeClaimsWithoutVerification(&claimsOfFirstIDToken) + require.NoError(t, err) + var responseBody map[string]interface{} require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &responseBody)) @@ -823,18 +836,43 @@ func TestTokenExchange(t *testing.T) { require.Equal(t, "N_A", responseBody["token_type"]) require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", responseBody["issued_token_type"]) - // Assert that the returned token has expected claims. + // Parse the returned token. parsedJWT, err := jose.ParseSigned(responseBody["access_token"].(string)) require.NoError(t, err) var tokenClaims map[string]interface{} require.NoError(t, json.Unmarshal(parsedJWT.UnsafePayloadWithoutVerification(), &tokenClaims)) - require.Contains(t, tokenClaims, "iat") - require.Contains(t, tokenClaims, "rat") - require.Contains(t, tokenClaims, "jti") + + // Make sure that these are the only fields in the token. + idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"} + require.ElementsMatch(t, idTokenFields, getMapKeys(tokenClaims)) + + // Assert that the returned token has expected claims values. + require.NotEmpty(t, tokenClaims["jti"]) + require.NotEmpty(t, tokenClaims["auth_time"]) + require.NotEmpty(t, tokenClaims["exp"]) + require.NotEmpty(t, tokenClaims["iat"]) + require.NotEmpty(t, tokenClaims["rat"]) + require.Empty(t, tokenClaims["nonce"]) // ID tokens only contain nonce during an authcode exchange require.Len(t, tokenClaims["aud"], 1) require.Contains(t, tokenClaims["aud"], test.requestedAudience) require.Equal(t, goodSubject, tokenClaims["sub"]) require.Equal(t, goodIssuer, tokenClaims["iss"]) + require.Equal(t, goodUsername, tokenClaims["username"]) + require.Equal(t, goodGroups, tokenClaims["groups"]) + + // Also assert that some are the same as the original downstream ID token. + requireClaimsAreEqual(t, "iss", claimsOfFirstIDToken, tokenClaims) // issuer + requireClaimsAreEqual(t, "sub", claimsOfFirstIDToken, tokenClaims) // subject + requireClaimsAreEqual(t, "rat", claimsOfFirstIDToken, tokenClaims) // requested at + requireClaimsAreEqual(t, "auth_time", claimsOfFirstIDToken, tokenClaims) // auth time + + // Also assert which are the different from the original downstream ID token. + requireClaimsAreNotEqual(t, "jti", claimsOfFirstIDToken, tokenClaims) // JWT ID + requireClaimsAreNotEqual(t, "aud", claimsOfFirstIDToken, tokenClaims) // audience + requireClaimsAreNotEqual(t, "exp", claimsOfFirstIDToken, tokenClaims) // expires at + require.Greater(t, tokenClaims["exp"], claimsOfFirstIDToken["exp"]) + requireClaimsAreNotEqual(t, "iat", claimsOfFirstIDToken, tokenClaims) // issued at + require.Greater(t, tokenClaims["iat"], claimsOfFirstIDToken["iat"]) // Assert that nothing in storage has been modified. newSecrets, err := secrets.List(context.Background(), metav1.ListOptions{}) @@ -1386,11 +1424,15 @@ func simulateAuthEndpointHavingAlreadyRun(t *testing.T, authRequest *http.Reques session := &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ Subject: goodSubject, - AuthTime: goodAuthTime, RequestedAt: goodRequestedAtTime, + AuthTime: goodAuthTime, + Extra: map[string]interface{}{ + oidc.DownstreamUsernameClaim: goodUsername, + oidc.DownstreamGroupsClaim: goodGroups, + }, }, - Subject: goodSubject, - Username: goodUsername, + Subject: "", // not used, note that callback_handler.go does not set this + Username: "", // not used, note that callback_handler.go does not set this } authRequester, err := oauthHelper.NewAuthorizeRequest(ctx, authRequest) require.NoError(t, err) @@ -1609,6 +1651,12 @@ func requireValidStoredRequest( require.Empty(t, claims.JTI) // When claims.JTI is empty, Fosite will generate a UUID for this field. require.Equal(t, goodSubject, claims.Subject) + // Our custom claims from the authorize endpoint should still be set. + require.Equal(t, map[string]interface{}{ + "username": goodUsername, + "groups": goodGroups, + }, claims.Extra) + // We are in charge of setting these fields. For the purpose of testing, we ensure that the // sentinel test value is set correctly. require.Equal(t, goodRequestedAtTime, claims.RequestedAt) @@ -1633,7 +1681,6 @@ func requireValidStoredRequest( require.Empty(t, claims.AuthenticationContextClassReference) require.Empty(t, claims.AuthenticationMethodsReference) require.Empty(t, claims.CodeHash) - require.Empty(t, claims.Extra) } // Assert that the session headers are what we think they should be. @@ -1664,9 +1711,9 @@ func requireValidStoredRequest( require.False(t, ok, "expected session to not hold expiration time for access token, but it did") } - // Assert that the session's username and subject are correct. - require.Equal(t, goodUsername, session.Username) - require.Equal(t, goodSubject, session.Subject) + // We don't use these, so they should be empty. + require.Empty(t, session.Username) + require.Empty(t, session.Subject) } func requireValidIDToken( @@ -1698,12 +1745,14 @@ func requireValidIDToken( IssuedAt int64 `json:"iat"` RequestedAt int64 `json:"rat"` AuthTime int64 `json:"auth_time"` + Groups string `json:"groups"` + Username string `json:"username"` } // Note that there is a bug in fosite which prevents the `at_hash` claim from appearing in this ID token // during the initial authcode exchange, but does not prevent `at_hash` from appearing in the refreshed ID token. // We can add a workaround for this later. - idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat"} + idTokenFields := []string{"sub", "aud", "iss", "jti", "nonce", "auth_time", "exp", "iat", "rat", "groups", "username"} if wantAtHashClaimInIDToken { idTokenFields = append(idTokenFields, "at_hash") } @@ -1717,6 +1766,8 @@ func requireValidIDToken( err := token.Claims(&claims) require.NoError(t, err) require.Equal(t, goodSubject, claims.Subject) + require.Equal(t, goodUsername, claims.Username) + require.Equal(t, goodGroups, claims.Groups) require.Len(t, claims.Audience, 1) require.Equal(t, goodClient, claims.Audience[0]) require.Equal(t, goodIssuer, claims.Issuer)