callback_handler.go: Add JWT Audience claim to storage

This commit is contained in:
Ryan Richard 2020-11-19 08:53:53 -08:00
parent ee84f31f42
commit a47617cad0
2 changed files with 21 additions and 10 deletions

View File

@ -72,7 +72,7 @@ func NewHandler(
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
r.Context(), r.Context(),
r.URL.Query().Get("code"), // TODO: do we need to validate this? authcode(r),
state.PKCECode, state.PKCECode,
state.Nonce, state.Nonce,
) )
@ -113,7 +113,7 @@ func NewHandler(
Claims: &jwt.IDTokenClaims{ Claims: &jwt.IDTokenClaims{
Issuer: downstreamIssuer, Issuer: downstreamIssuer,
Subject: username, Subject: username,
Audience: []string{"my-client"}, // TODO use the right value here Audience: []string{downstreamAuthParams.Get("client_id")},
ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here
IssuedAt: now, // TODO test this IssuedAt: now, // TODO test this
RequestedAt: now, // TODO test this RequestedAt: now, // TODO test this
@ -133,6 +133,10 @@ func NewHandler(
}) })
} }
func authcode(r *http.Request) string {
return r.FormValue("code")
}
func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
@ -144,7 +148,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return nil, err return nil, err
} }
if r.FormValue("code") == "" { if authcode(r) == "" {
plog.Info("code param not found") plog.Info("code param not found")
return nil, httperr.New(http.StatusBadRequest, "code param not found") return nil, httperr.New(http.StatusBadRequest, "code param not found")
} }

View File

@ -36,6 +36,12 @@ func TestCallbackEndpoint(t *testing.T) {
downstreamIssuer = "https://my-downstream-issuer.com/path" downstreamIssuer = "https://my-downstream-issuer.com/path"
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
happyUpstreamAuthcode = "upstream-auth-code" happyUpstreamAuthcode = "upstream-auth-code"
upstreamUsername = "test-pinniped-username"
downstreamClientID = "pinniped-cli"
)
var (
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
) )
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
@ -47,8 +53,8 @@ func TestCallbackEndpoint(t *testing.T) {
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{}, return oidcclient.Token{},
map[string]interface{}{ map[string]interface{}{
"the-user-claim": "test-pinniped-username", "the-user-claim": upstreamUsername,
"the-groups-claim": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, "the-groups-claim": upstreamGroupMembership,
"other-claim": "should be ignored", "other-claim": "should be ignored",
}, },
nil nil
@ -62,8 +68,8 @@ func TestCallbackEndpoint(t *testing.T) {
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{}, return oidcclient.Token{},
map[string]interface{}{ map[string]interface{}{
"sub": "test-pinniped-username", "sub": upstreamUsername,
"groups": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, "groups": upstreamGroupMembership,
"other-claim": "should be ignored", "other-claim": "should be ignored",
}, },
nil nil
@ -104,7 +110,7 @@ func TestCallbackEndpoint(t *testing.T) {
happyOriginalRequestParamsQuery := url.Values{ happyOriginalRequestParamsQuery := url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"openid profile email"}, "scope": []string{"openid profile email"},
"client_id": []string{"pinniped-cli"}, "client_id": []string{downstreamClientID},
"state": []string{happyDownstreamState}, "state": []string{happyDownstreamState},
"nonce": []string{"some-nonce-value"}, "nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"}, "code_challenge": []string{"some-challenge"},
@ -451,8 +457,9 @@ func TestCallbackEndpoint(t *testing.T) {
require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") require.NotContains(t, storedRequest.GetGrantedScopes(), "openid")
} }
require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer) require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer)
require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) require.Equal(t, upstreamUsername, storedSession.Claims.Subject)
require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience)
require.Equal(t, upstreamGroupMembership, storedSession.Claims.Extra["oidc.pinniped.dev/groups"])
} else { } else {
require.Empty(t, rsp.Header().Values("Location")) require.Empty(t, rsp.Header().Values("Location"))
} }