diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 985ade31..93073cd6 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -21,6 +21,21 @@ import ( "go.pinniped.dev/internal/plog" ) +const ( + // defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC + // ID token if the upstream OIDC IDP did not tell us to use another claim. + defaultUpstreamUsernameClaim = "sub" + + // defaultUpstreamGroupsClaim is what we will use to extract the groups from an upstream OIDC ID + // token if the upstream OIDC IDP did not tell us to use another claim. + defaultUpstreamGroupsClaim = "groups" + + // downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token + // information. + // TODO: should this be per-issuer? Or per version? + downstreamGroupsClaim = "oidc.pinniped.dev/groups" +) + func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) @@ -61,14 +76,32 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi } var username string - // TODO handle the case when upstreamIDPConfig.GetUsernameClaim() is the empty string by defaulting to something reasonable - usernameAsInterface := idTokenClaims[upstreamIDPConfig.GetUsernameClaim()] - username, ok := usernameAsInterface.(string) + usernameClaim := upstreamIDPConfig.GetUsernameClaim() + if usernameClaim == "" { + usernameClaim = defaultUpstreamUsernameClaim + } + usernameAsInterface, ok := idTokenClaims[usernameClaim] + if !ok { + panic(err) // TODO + } + username, ok = usernameAsInterface.(string) if !ok { panic(err) // TODO } - // TODO also look at the upstream ID token's groups claim and store that value as a downstream ID token claim + var groups []string + groupsClaim := upstreamIDPConfig.GetGroupsClaim() + if groupsClaim == "" { + groupsClaim = defaultUpstreamGroupsClaim + } + groupsAsInterface, ok := idTokenClaims[groupsClaim] + if !ok { + panic(err) // TODO + } + groups, ok = groupsAsInterface.([]string) + if !ok { + panic(err) // TODO + } now := time.Now() authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &openid.DefaultSession{ @@ -80,6 +113,9 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi IssuedAt: now, // TODO test this RequestedAt: now, // TODO test this AuthTime: now, // TODO test this + Extra: map[string]interface{}{ + downstreamGroupsClaim: groups, + }, }, }) if err != nil { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index b6585a3c..a375f095 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -47,8 +47,24 @@ func TestCallbackEndpoint(t *testing.T) { ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { return oidcclient.Token{}, map[string]interface{}{ - "the-user-claim": "test-pinniped-username", - "other-claim": "should be ignored", + "the-user-claim": "test-pinniped-username", + "the-groups-claim": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "other-claim": "should be ignored", + }, + nil + }, + } + + defaultClaimsUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + Scopes: []string{"scope1", "scope2"}, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, + map[string]interface{}{ + "sub": "test-pinniped-username", + "groups": []string{"test-pinniped-group-0", "test-pinniped-group-1"}, + "other-claim": "should be ignored", }, nil }, @@ -177,6 +193,9 @@ func TestCallbackEndpoint(t *testing.T) { ExpectedIDTokenNonce: nonce.Nonce(happyNonce), } + // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it + happyRedirectLocationRegexp := downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState + tests := []struct { name string @@ -195,14 +214,26 @@ func TestCallbackEndpoint(t *testing.T) { wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { - name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idp: upstreamOIDCIdentityProvider, - method: http.MethodGet, - path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, - // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyRedirectLocationRegexp, + wantAuthcodeStored: true, + wantGrantedOpenidScope: true, + wantBody: "", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, + }, + { + name: "upstream IDP uses default claims", + idp: defaultClaimsUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: happyRedirectLocationRegexp, wantAuthcodeStored: true, wantGrantedOpenidScope: true, wantBody: "", @@ -418,6 +449,7 @@ func TestCallbackEndpoint(t *testing.T) { require.NotContains(t, storedRequest.GetGrantedScopes(), "openid") } require.Equal(t, "test-pinniped-username", storedSession.Claims.Subject) + require.Equal(t, []string{"test-pinniped-group-0", "test-pinniped-group-1"}, storedSession.Claims.Extra["oidc.pinniped.dev/groups"]) } else { require.Empty(t, rsp.Header().Values("Location")) }