Assert more cluster-scoped ID token claims in supervisor_login_test.go

This commit is contained in:
Ryan Richard 2023-01-17 13:10:51 -08:00
parent 6156fdf175
commit 74c3156059

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package integration package integration
@ -2071,7 +2071,7 @@ func testSupervisorLogin(
if len(wantDownstreamIDTokenAdditionalClaims) > 0 { if len(wantDownstreamIDTokenAdditionalClaims) > 0 {
expectedIDTokenClaims = append(expectedIDTokenClaims, "additionalClaims") expectedIDTokenClaims = append(expectedIDTokenClaims, "additionalClaims")
} }
verifyTokenResponse( initialIDTokenClaims := verifyTokenResponse(
t, t,
tokenResponse, tokenResponse,
discovery, discovery,
@ -2088,7 +2088,7 @@ func testSupervisorLogin(
if requestTokenExchangeAud == "" { if requestTokenExchangeAud == "" {
requestTokenExchangeAud = "some-cluster-123" // use a default test value requestTokenExchangeAud = "some-cluster-123" // use a default test value
} }
doTokenExchange(t, requestTokenExchangeAud, &downstreamOAuth2Config, tokenResponse, httpClient, discovery, wantTokenExchangeResponse) doTokenExchange(t, requestTokenExchangeAud, &downstreamOAuth2Config, tokenResponse, httpClient, discovery, wantTokenExchangeResponse, initialIDTokenClaims)
wantRefreshedGroups := wantDownstreamIDTokenGroups wantRefreshedGroups := wantDownstreamIDTokenGroups
if editRefreshSessionDataWithoutBreaking != nil { if editRefreshSessionDataWithoutBreaking != nil {
@ -2131,7 +2131,7 @@ func testSupervisorLogin(
if len(wantDownstreamIDTokenAdditionalClaims) > 0 { if len(wantDownstreamIDTokenAdditionalClaims) > 0 {
expectRefreshedIDTokenClaims = append(expectRefreshedIDTokenClaims, "additionalClaims") expectRefreshedIDTokenClaims = append(expectRefreshedIDTokenClaims, "additionalClaims")
} }
verifyTokenResponse( refreshedIDTokenClaims := verifyTokenResponse(
t, t,
refreshedTokenResponse, refreshedTokenResponse,
discovery, discovery,
@ -2149,7 +2149,7 @@ func testSupervisorLogin(
require.NotEqual(t, tokenResponse.Extra("id_token"), refreshedTokenResponse.Extra("id_token")) require.NotEqual(t, tokenResponse.Extra("id_token"), refreshedTokenResponse.Extra("id_token"))
// token exchange on the refreshed token // token exchange on the refreshed token
doTokenExchange(t, requestTokenExchangeAud, &downstreamOAuth2Config, refreshedTokenResponse, httpClient, discovery, wantTokenExchangeResponse) doTokenExchange(t, requestTokenExchangeAud, &downstreamOAuth2Config, refreshedTokenResponse, httpClient, discovery, wantTokenExchangeResponse, refreshedIDTokenClaims)
// Now that we have successfully performed a refresh, let's test what happens when an // Now that we have successfully performed a refresh, let's test what happens when an
// upstream refresh fails during the next downstream refresh. // upstream refresh fails during the next downstream refresh.
@ -2206,7 +2206,7 @@ func verifyTokenResponse(
wantDownstreamIDTokenSubjectToMatch, wantDownstreamIDTokenUsernameToMatch string, wantDownstreamIDTokenSubjectToMatch, wantDownstreamIDTokenUsernameToMatch string,
wantDownstreamIDTokenGroups []string, wantDownstreamIDTokenGroups []string,
wantDownstreamIDTokenAdditionalClaims map[string]interface{}, wantDownstreamIDTokenAdditionalClaims map[string]interface{},
) { ) map[string]interface{} {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
@ -2274,6 +2274,8 @@ func verifyTokenResponse(
actualAccessTokenHashClaimValue := idTokenClaims["at_hash"] actualAccessTokenHashClaimValue := idTokenClaims["at_hash"]
require.NotEmpty(t, actualAccessTokenHashClaimValue) require.NotEmpty(t, actualAccessTokenHashClaimValue)
require.Equal(t, hashAccessToken(tokenResponse.AccessToken), actualAccessTokenHashClaimValue) require.Equal(t, hashAccessToken(tokenResponse.AccessToken), actualAccessTokenHashClaimValue)
return idTokenClaims
} }
func hashAccessToken(accessToken string) string { func hashAccessToken(accessToken string) string {
@ -2486,6 +2488,7 @@ func doTokenExchange(
httpClient *http.Client, httpClient *http.Client,
provider *coreosoidc.Provider, provider *coreosoidc.Provider,
wantTokenExchangeResponse func(t *testing.T, status int, body string), wantTokenExchangeResponse func(t *testing.T, status int, body string),
previousIDTokenClaims map[string]interface{},
) { ) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
@ -2529,6 +2532,7 @@ func doTokenExchange(
} }
require.NoError(t, json.NewDecoder(resp.Body).Decode(&respBody)) require.NoError(t, json.NewDecoder(resp.Body).Decode(&respBody))
// Note that this validates the "aud" claim, among other things.
var clusterVerifier = provider.Verifier(&coreosoidc.Config{ClientID: requestTokenExchangeAud}) var clusterVerifier = provider.Verifier(&coreosoidc.Config{ClientID: requestTokenExchangeAud})
exchangedToken, err := clusterVerifier.Verify(ctx, respBody.AccessToken) exchangedToken, err := clusterVerifier.Verify(ctx, respBody.AccessToken)
require.NoError(t, err) require.NoError(t, err)
@ -2539,6 +2543,18 @@ func doTokenExchange(
require.NoError(t, err) require.NoError(t, err)
t.Logf("exchanged token claims:\n%s", string(indentedClaims)) t.Logf("exchanged token claims:\n%s", string(indentedClaims))
// Some claims should be identical to the previously issued ID token.
require.Equal(t, previousIDTokenClaims["iss"], claims["iss"])
require.Equal(t, previousIDTokenClaims["sub"], claims["sub"])
require.Equal(t, previousIDTokenClaims["username"], claims["username"])
require.Equal(t, previousIDTokenClaims["groups"], claims["groups"]) // may be nil in some test cases
require.Equal(t, previousIDTokenClaims["additionalClaims"], claims["additionalClaims"]) // may be nil in some test cases
require.Equal(t, previousIDTokenClaims["auth_time"], claims["auth_time"])
require.Contains(t, claims, "rat") // requested at
require.Contains(t, claims, "iat") // issued at
require.Contains(t, claims, "exp") // expires at
require.Contains(t, claims, "jti") // JWT ID
// The original client ID should be preserved in the azp claim, therefore preserving this information // The original client ID should be preserved in the azp claim, therefore preserving this information
// about the original source of the authorization for tracing/auditing purposes, since the "aud" claim // about the original source of the authorization for tracing/auditing purposes, since the "aud" claim
// has been updated to have a new value. // has been updated to have a new value.