callback_handler.go: start to test upstream token corner cases

Also refactor to get rid of duplicate test structs.

Also also don't default groups ID token claim because there is no standard one.

Also also also add some logging that will hopefully help us in debugging in the
future.

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Ryan Richard 2020-11-19 14:19:01 -05:00 committed by Andrew Keesler
parent a47617cad0
commit 83101eefce
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
2 changed files with 189 additions and 114 deletions

View File

@ -26,14 +26,9 @@ const (
// ID token if the upstream OIDC IDP did not tell us to use another claim.
defaultUpstreamUsernameClaim = "sub"
// defaultUpstreamGroupsClaim is what we will use to extract the groups from an upstream OIDC ID
// token if the upstream OIDC IDP did not tell us to use another claim.
defaultUpstreamGroupsClaim = "groups"
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
// TODO: should this be per-issuer? Or per version?
downstreamGroupsClaim = "oidc.pinniped.dev/groups"
downstreamGroupsClaim = "groups"
)
func NewHandler(
@ -56,6 +51,7 @@ func NewHandler(
downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
if err != nil {
plog.Error("error reading state downstream auth params", err)
return httperr.New(http.StatusBadRequest, "error reading state downstream auth params")
}
@ -63,11 +59,11 @@ func NewHandler(
reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams}
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest)
if err != nil {
plog.Error("error using state downstream auth params", err)
return httperr.New(http.StatusBadRequest, "error using state downstream auth params")
}
// Grant the openid scope only if it was requested.
// TODO: shouldn't we be potentially granting more scopes than just openid...
grantOpenIDScopeIfRequested(authorizeRequester)
_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
@ -77,52 +73,18 @@ func NewHandler(
state.Nonce,
)
if err != nil {
plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName())
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
}
var username string
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
if usernameClaim == "" {
usernameClaim = defaultUpstreamUsernameClaim
}
usernameAsInterface, ok := idTokenClaims[usernameClaim]
if !ok {
panic(err) // TODO
}
username, ok = usernameAsInterface.(string)
if !ok {
panic(err) // TODO
username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
if err != nil {
return err
}
var groups []string
groupsClaim := upstreamIDPConfig.GetGroupsClaim()
if groupsClaim == "" {
groupsClaim = defaultUpstreamGroupsClaim
}
groupsAsInterface, ok := idTokenClaims[groupsClaim]
if !ok {
panic(err) // TODO
}
groups, ok = groupsAsInterface.([]string)
if !ok {
panic(err) // TODO
}
now := time.Now()
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Issuer: downstreamIssuer,
Subject: username,
Audience: []string{downstreamAuthParams.Get("client_id")},
ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here
IssuedAt: now, // TODO test this
RequestedAt: now, // TODO test this
AuthTime: now, // TODO test this
Extra: map[string]interface{}{
downstreamGroupsClaim: groups,
},
},
})
groups := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups)
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
if err != nil {
panic(err) // TODO
}
@ -222,3 +184,76 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
}
}
}
func getUsernameFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{},
) (string, error) {
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
if usernameClaim == "" {
// TODO: if we use the default "sub" claim, maybe we should create the username with the issuer
// since the spec says the "sub" claim is only unique per issuer.
usernameClaim = defaultUpstreamUsernameClaim
}
usernameAsInterface, ok := idTokenClaims[usernameClaim]
if !ok {
plog.Warning(
"no username claim in upstream ID token",
"upstreamName", upstreamIDPConfig.GetName(),
"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
"usernameClaim", usernameClaim,
)
return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
}
username, ok := usernameAsInterface.(string)
if !ok {
panic("todo bbb") // TODO
}
return username, nil
}
func getGroupsFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{},
) []string {
groupsClaim := upstreamIDPConfig.GetGroupsClaim()
if groupsClaim == "" {
return nil
}
groupsAsInterface, ok := idTokenClaims[groupsClaim]
if !ok {
panic("todo ccc") // TODO
}
groups, ok := groupsAsInterface.([]string)
if !ok {
panic("todo ddd") // TODO
}
return groups
}
func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession {
now := time.Now()
openIDSession := &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Issuer: issuer,
Subject: username,
Audience: []string{clientID},
ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here
IssuedAt: now, // TODO test this
RequestedAt: now, // TODO test this
AuthTime: now, // TODO test this
},
}
if groups != nil {
openIDSession.Claims.Extra = map[string]interface{}{
downstreamGroupsClaim: groups,
}
}
return openIDSession
}

View File

@ -29,6 +29,16 @@ import (
const (
happyUpstreamIDPName = "upstream-idp-name"
upstreamSubject = "abc123-some-guid"
upstreamUsername = "test-pinniped-username"
upstreamUsernameClaim = "the-user-claim"
upstreamGroupsClaim = "the-groups-claim"
)
var (
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
)
func TestCallbackEndpoint(t *testing.T) {
@ -36,63 +46,15 @@ func TestCallbackEndpoint(t *testing.T) {
downstreamIssuer = "https://my-downstream-issuer.com/path"
downstreamRedirectURI = "http://127.0.0.1/callback"
happyUpstreamAuthcode = "upstream-auth-code"
upstreamUsername = "test-pinniped-username"
downstreamClientID = "pinniped-cli"
)
var (
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
)
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
UsernameClaim: "the-user-claim",
GroupsClaim: "the-groups-claim",
Scopes: []string{"scope1", "scope2"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{},
map[string]interface{}{
"the-user-claim": upstreamUsername,
"the-groups-claim": upstreamGroupMembership,
"other-claim": "should be ignored",
},
nil
},
}
defaultClaimsUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
Scopes: []string{"scope1", "scope2"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{},
map[string]interface{}{
"sub": upstreamUsername,
"groups": upstreamGroupMembership,
"other-claim": "should be ignored",
},
nil
},
}
otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: "other-upstream-idp-name",
ClientID: "other-some-client-id",
Scopes: []string{"other-scope1", "other-scope2"},
}
failedExchangeUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: upstreamOIDCIdentityProvider.ClientID,
UsernameClaim: upstreamOIDCIdentityProvider.UsernameClaim,
GroupsClaim: upstreamOIDCIdentityProvider.GroupsClaim,
Scopes: upstreamOIDCIdentityProvider.Scopes,
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{}, nil, errors.New("some exchange error")
},
}
var stateEncoderHashKey = []byte("fake-hash-secret")
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
var cookieEncoderHashKey = []byte("fake-hash-secret2")
@ -210,16 +172,18 @@ func TestCallbackEndpoint(t *testing.T) {
path string
csrfCookie string
wantStatus int
wantBody string
wantRedirectLocationRegexp string
wantGrantedOpenidScope bool
wantStatus int
wantBody string
wantRedirectLocationRegexp string
wantGrantedOpenidScope bool
wantDownstreamIDTokenSubject string
wantDownstreamIDTokenGroups []string
wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs
}{
{
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie,
@ -227,11 +191,13 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream IDP uses default claims",
idp: defaultClaimsUpstreamOIDCIdentityProvider,
name: "upstream IDP provides no username or group claim, so we use default username claim and skip groups",
idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie,
@ -239,6 +205,7 @@ func TestCallbackEndpoint(t *testing.T) {
wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamSubject,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
@ -290,7 +257,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(),
csrfCookie: happyCSRFCookie,
@ -299,7 +266,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "state's internal version does not match what we want",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongVersionState).String(),
csrfCookie: happyCSRFCookie,
@ -308,7 +275,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "state's downstream auth params element is invalid",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(),
csrfCookie: happyCSRFCookie,
@ -317,7 +284,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "state's downstream auth params are missing required value (e.g., client_id)",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(missingClientIDState).String(),
csrfCookie: happyCSRFCookie,
@ -326,12 +293,14 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "state's downstream auth params does not contain openid scope",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
@ -345,7 +314,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "the CSRF cookie does not exist on request",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
wantStatus: http.StatusForbidden,
@ -353,7 +322,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
@ -362,7 +331,7 @@ func TestCallbackEndpoint(t *testing.T) {
},
{
name: "cookie csrf value does not match state csrf value",
idp: upstreamOIDCIdentityProvider,
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongCSRFValueState).String(),
csrfCookie: happyCSRFCookie,
@ -373,7 +342,7 @@ func TestCallbackEndpoint(t *testing.T) {
// Upstream exchange
{
name: "upstream auth code exchange fails",
idp: failedExchangeUpstreamOIDCIdentityProvider,
idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie,
@ -381,6 +350,16 @@ func TestCallbackEndpoint(t *testing.T) {
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token does not contain requested username claim",
idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: no username claim in upstream ID token\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
}
for _, test := range tests {
test := test
@ -457,9 +436,13 @@ func TestCallbackEndpoint(t *testing.T) {
require.NotContains(t, storedRequest.GetGrantedScopes(), "openid")
}
require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer)
require.Equal(t, upstreamUsername, storedSession.Claims.Subject)
require.Equal(t, test.wantDownstreamIDTokenSubject, storedSession.Claims.Subject)
require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience)
require.Equal(t, upstreamGroupMembership, storedSession.Claims.Extra["oidc.pinniped.dev/groups"])
if test.wantDownstreamIDTokenGroups != nil {
require.Equal(t, test.wantDownstreamIDTokenGroups, storedSession.Claims.Extra["groups"])
} else {
require.NotContains(t, storedSession.Claims.Extra, "groups")
}
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
@ -519,6 +502,63 @@ func (r *requestPath) String() string {
return path + params.Encode()
}
type upstreamOIDCIdentityProviderBuilder struct {
idToken map[string]interface{}
usernameClaim, groupsClaim string
authcodeExchangeErr error
}
func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
return &upstreamOIDCIdentityProviderBuilder{
usernameClaim: upstreamUsernameClaim,
groupsClaim: upstreamGroupsClaim,
idToken: map[string]interface{}{
"sub": upstreamSubject,
upstreamUsernameClaim: upstreamUsername,
upstreamGroupsClaim: upstreamGroupMembership,
"other-claim": "should be ignored",
},
}
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = ""
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *upstreamOIDCIdentityProviderBuilder {
u.groupsClaim = ""
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name, value string) *upstreamOIDCIdentityProviderBuilder {
u.idToken[name] = value
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
delete(u.idToken, claim)
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeError(err error) *upstreamOIDCIdentityProviderBuilder {
u.authcodeExchangeErr = err
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCIdentityProvider {
return testutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
UsernameClaim: u.usernameClaim,
GroupsClaim: u.groupsClaim,
Scopes: []string{"scope1", "scope2"},
ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) {
return oidcclient.Token{}, u.idToken, u.authcodeExchangeErr
},
}
}
func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values {
copied := url.Values{}
for key, value := range query {