Allow user-defined string & stringList consts for use in CEL expressions

This commit is contained in:
Ryan Richard 2023-02-06 17:04:59 -08:00
parent 5385fb38db
commit 1a53b4daea
2 changed files with 188 additions and 32 deletions

View File

@ -12,6 +12,7 @@ import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
"time"
@ -23,8 +24,10 @@ import (
)
const (
usernameVariableName = "username"
groupsVariableName = "groups"
usernameVariableName = "username"
groupsVariableName = "groups"
constStringVariableName = "strConst"
constStringListVariableName = "strListConst"
defaultPolicyRejectedAuthMessage = "Authentication was rejected by a configured policy"
)
@ -46,15 +49,51 @@ func NewCELTransformer(maxExpressionRuntime time.Duration) (*CELTransformer, err
return &CELTransformer{compiler: env, maxExpressionRuntime: maxExpressionRuntime}, nil
}
// TransformationConstants can be used to make more variables available to compiled CEL expressions for convenience.
type TransformationConstants struct {
// A map of variable names to their string values. If a key "x" has value "123", then it will be available
// to CEL expressions as the variable `strConst.x` with value `"123"`.
StringConstants map[string]string
// A map of variable names to their string list values. If a key "x" has value []string{"123","456"},
// then it will be available to CEL expressions as the variable `strListConst.x` with value `["123","456"]`.
StringListConstants map[string][]string
}
// Valid identifiers in CEL expressions are defined by the CEL language spec as: [_a-zA-Z][_a-zA-Z0-9]*
var validIdentifiersRegexp = regexp.MustCompile(`^[_a-zA-Z][_a-zA-Z0-9]*$`)
func (t *TransformationConstants) validateVariableNames() error {
const errFormat = "%q is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)"
for k := range t.StringConstants {
if !validIdentifiersRegexp.MatchString(k) {
return fmt.Errorf(errFormat, k)
}
}
for k := range t.StringListConstants {
if !validIdentifiersRegexp.MatchString(k) {
return fmt.Errorf(errFormat, k)
}
}
return nil
}
// CompileTransformation compiles a CEL-based identity transformation expression.
// The compiled transform can be cached in memory and executed repeatedly and in a thread-safe way.
func (c *CELTransformer) CompileTransformation(t CELTransformation) (idtransform.IdentityTransformation, error) {
return t.compile(c)
// The caller must not modify the consts param struct after calling this function to allow
// the returned IdentityTransformation to use it as a thread-safe read-only structure.
func (c *CELTransformer) CompileTransformation(t CELTransformation, consts *TransformationConstants) (idtransform.IdentityTransformation, error) {
if consts == nil {
consts = &TransformationConstants{}
}
if err := consts.validateVariableNames(); err != nil {
return nil, err
}
return t.compile(c, consts)
}
// CELTransformation can be compiled into an IdentityTransformation.
type CELTransformation interface {
compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error)
compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error)
}
// UsernameTransformation is a CEL expression that can transform a username (or leave it unchanged).
@ -108,76 +147,91 @@ func compileProgram(transformer *CELTransformer, expectedExpressionType *cel.Typ
return program, nil
}
func (t *UsernameTransformation) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) {
func (t *UsernameTransformation) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) {
program, err := compileProgram(transformer, cel.StringType, t.Expression)
if err != nil {
return nil, err
}
return &compiledUsernameTransformation{
program: program,
maxExpressionRuntime: transformer.maxExpressionRuntime,
baseCompiledTransformation: &baseCompiledTransformation{
program: program,
consts: consts,
maxExpressionRuntime: transformer.maxExpressionRuntime,
},
}, nil
}
func (t *GroupsTransformation) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) {
func (t *GroupsTransformation) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) {
program, err := compileProgram(transformer, cel.ListType(cel.StringType), t.Expression)
if err != nil {
return nil, err
}
return &compiledGroupsTransformation{
program: program,
maxExpressionRuntime: transformer.maxExpressionRuntime,
baseCompiledTransformation: &baseCompiledTransformation{
program: program,
consts: consts,
maxExpressionRuntime: transformer.maxExpressionRuntime,
},
}, nil
}
func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) {
func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) {
program, err := compileProgram(transformer, cel.BoolType, t.Expression)
if err != nil {
return nil, err
}
return &compiledAllowAuthenticationPolicy{
program: program,
maxExpressionRuntime: transformer.maxExpressionRuntime,
baseCompiledTransformation: &baseCompiledTransformation{
program: program,
consts: consts,
maxExpressionRuntime: transformer.maxExpressionRuntime,
},
rejectedAuthenticationMessage: t.RejectedAuthenticationMessage,
}, nil
}
// Base type for common aspects of compiled transformations.
type baseCompiledTransformation struct {
program cel.Program
consts *TransformationConstants
maxExpressionRuntime time.Duration
}
// Implements idtransform.IdentityTransformation.
type compiledUsernameTransformation struct {
program cel.Program
maxExpressionRuntime time.Duration
*baseCompiledTransformation
}
// Implements idtransform.IdentityTransformation.
type compiledGroupsTransformation struct {
program cel.Program
maxExpressionRuntime time.Duration
*baseCompiledTransformation
}
// Implements idtransform.IdentityTransformation.
type compiledAllowAuthenticationPolicy struct {
program cel.Program
maxExpressionRuntime time.Duration
*baseCompiledTransformation
rejectedAuthenticationMessage string
}
func evalProgram(ctx context.Context, program cel.Program, maxExpressionRuntime time.Duration, username string, groups []string) (ref.Val, error) {
func (c *baseCompiledTransformation) evalProgram(ctx context.Context, username string, groups []string) (ref.Val, error) {
// Limit the runtime of a CEL expression to avoid accidental very expensive expressions.
timeoutCtx, cancel := context.WithTimeout(ctx, maxExpressionRuntime)
timeoutCtx, cancel := context.WithTimeout(ctx, c.maxExpressionRuntime)
defer cancel()
// Evaluation is thread-safe and side effect free. Many inputs can be sent to the same cel.Program
// and if fields are present in the input, but not referenced in the expression, they are ignored.
// The argument to Eval may either be an `interpreter.Activation` or a `map[string]interface{}`.
val, _, err := program.ContextEval(timeoutCtx, map[string]interface{}{
usernameVariableName: username,
groupsVariableName: groups,
val, _, err := c.program.ContextEval(timeoutCtx, map[string]interface{}{
usernameVariableName: username,
groupsVariableName: groups,
constStringVariableName: c.consts.StringConstants,
constStringListVariableName: c.consts.StringListConstants,
})
return val, err
}
func (c *compiledUsernameTransformation) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) {
val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups)
val, err := c.evalProgram(ctx, username, groups)
if err != nil {
return nil, err
}
@ -197,7 +251,7 @@ func (c *compiledUsernameTransformation) Evaluate(ctx context.Context, username
}
func (c *compiledGroupsTransformation) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) {
val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups)
val, err := c.evalProgram(ctx, username, groups)
if err != nil {
return nil, err
}
@ -217,7 +271,7 @@ func (c *compiledGroupsTransformation) Evaluate(ctx context.Context, username st
}
func (c *compiledAllowAuthenticationPolicy) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) {
val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups)
val, err := c.evalProgram(ctx, username, groups)
if err != nil {
return nil, err
}
@ -254,6 +308,8 @@ func newEnv() (*cel.Env, error) {
// the parsing/checking phase.
cel.Variable(usernameVariableName, cel.StringType),
cel.Variable(groupsVariableName, cel.ListType(cel.StringType)),
cel.Variable(constStringVariableName, cel.MapType(cel.StringType, cel.StringType)),
cel.Variable(constStringListVariableName, cel.MapType(cel.StringType, cel.ListType(cel.StringType))),
// Enable the strings extensions.
// See https://github.com/google/cel-go/tree/master/ext#strings

View File

@ -30,6 +30,7 @@ func TestTransformer(t *testing.T) {
username string
groups []string
transforms []CELTransformation
consts *TransformationConstants
ctx context.Context
wantUsername string
@ -113,6 +114,28 @@ func TestTransformer(t *testing.T) {
wantUsername: "other",
wantGroups: []string{"admins", "developers", "other", "ryan", "other2"},
},
{
name: "any transformation can use the provided constants as variables",
username: "ryan",
groups: []string{"admins", "developers", "other"},
consts: &TransformationConstants{
StringConstants: map[string]string{
"x": "abc",
"y": "def",
},
StringListConstants: map[string][]string{
"x": {"uvw", "xyz"},
"y": {"123", "456"},
},
},
transforms: []CELTransformation{
&UsernameTransformation{Expression: `strConst.x + strListConst.x[0]`},
&GroupsTransformation{Expression: `[strConst.x, strConst.y, strListConst.x[1], strListConst.y[0]]`},
&AllowAuthenticationPolicy{Expression: `strConst.x == "abc"`},
},
wantUsername: "abcuvw",
wantGroups: []string{"abc", "def", "xyz", "123"},
},
{
name: "the CEL string extensions are enabled for use in the expressions",
username: " ryan ",
@ -219,6 +242,19 @@ func TestTransformer(t *testing.T) {
wantUsername: "ryan",
wantGroups: []string{"admins", "developers"},
},
{
name: "can filter groups based on an allow list provided as a const",
username: "ryan",
groups: []string{"admins", "developers", "other"},
consts: &TransformationConstants{
StringListConstants: map[string][]string{"allowedGroups": {"admins", "developers"}},
},
transforms: []CELTransformation{
&GroupsTransformation{Expression: `groups.filter(g, g in strListConst.allowedGroups)`},
},
wantUsername: "ryan",
wantGroups: []string{"admins", "developers"},
},
{
name: "can filter groups based on a disallow list",
username: "ryan",
@ -239,6 +275,19 @@ func TestTransformer(t *testing.T) {
wantUsername: "ryan",
wantGroups: []string{"other"},
},
{
name: "can filter groups based on a disallowed prefixes provided as a const",
username: "ryan",
groups: []string{"disallowed1:admins", "disallowed2:developers", "other"},
consts: &TransformationConstants{
StringListConstants: map[string][]string{"disallowedPrefixes": {"disallowed1:", "disallowed2:"}},
},
transforms: []CELTransformation{
&GroupsTransformation{Expression: `groups.filter(group, !(strListConst.disallowedPrefixes.exists(prefix, group.startsWith(prefix))))`},
},
wantUsername: "ryan",
wantGroups: []string{"other"},
},
{
name: "can add a group",
username: "ryan",
@ -249,6 +298,19 @@ func TestTransformer(t *testing.T) {
wantUsername: "ryan",
wantGroups: []string{"admins", "developers", "other", "new-group"},
},
{
name: "can add a group from a const",
username: "ryan",
groups: []string{"admins", "developers", "other"},
consts: &TransformationConstants{
StringConstants: map[string]string{"groupToAlwaysAdd": "new-group"},
},
transforms: []CELTransformation{
&GroupsTransformation{Expression: `groups + [strConst.groupToAlwaysAdd]`},
},
wantUsername: "ryan",
wantGroups: []string{"admins", "developers", "other", "new-group"},
},
{
name: "can add a group but only if they already belong to another group - when the user does belong to that other group",
username: "ryan",
@ -622,6 +684,44 @@ func TestTransformer(t *testing.T) {
},
wantCompileErr: `CEL expression should return type "list(string)" but returns type "list(dyn)"`,
},
{
name: "using string constants which were not were provided",
username: "ryan",
groups: []string{"admins", "developers", "other"},
transforms: []CELTransformation{
&UsernameTransformation{Expression: `strConst.x`},
},
wantEvaluationErr: `identity transformation at index 0: no such key: x`,
},
{
name: "using string list constants which were not were provided",
username: "ryan",
groups: []string{"admins", "developers", "other"},
transforms: []CELTransformation{
&GroupsTransformation{Expression: `strListConst.x`},
},
wantEvaluationErr: `identity transformation at index 0: no such key: x`,
},
{
name: "using an illegal name for a string constant",
username: "ryan",
groups: []string{"admins", "developers", "other"},
consts: &TransformationConstants{StringConstants: map[string]string{" illegal": "a"}},
transforms: []CELTransformation{
&UsernameTransformation{Expression: `username`},
},
wantCompileErr: `" illegal" is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)`,
},
{
name: "using an illegal name for a stringList constant",
username: "ryan",
groups: []string{"admins", "developers", "other"},
consts: &TransformationConstants{StringListConstants: map[string][]string{" illegal": {"a"}}},
transforms: []CELTransformation{
&UsernameTransformation{Expression: `username`},
},
wantCompileErr: `" illegal" is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)`,
},
}
for _, tt := range tests {
@ -635,7 +735,7 @@ func TestTransformer(t *testing.T) {
pipeline := idtransform.NewTransformationPipeline()
for _, transform := range tt.transforms {
compiledTransform, err := transformer.CompileTransformation(transform)
compiledTransform, err := transformer.CompileTransformation(transform, tt.consts)
if tt.wantCompileErr != "" {
require.EqualError(t, err, tt.wantCompileErr)
return // the rest of the test doesn't make sense when there was a compile error
@ -673,13 +773,13 @@ func TestTypicalPerformanceAndThreadSafety(t *testing.T) {
pipeline := idtransform.NewTransformationPipeline()
var compiledTransform idtransform.IdentityTransformation
compiledTransform, err = transformer.CompileTransformation(&UsernameTransformation{Expression: `"username_prefix:" + username`})
compiledTransform, err = transformer.CompileTransformation(&UsernameTransformation{Expression: `"username_prefix:" + username`}, nil)
require.NoError(t, err)
pipeline.AppendTransformation(compiledTransform)
compiledTransform, err = transformer.CompileTransformation(&GroupsTransformation{Expression: `groups.map(g, "group_prefix:" + g)`})
compiledTransform, err = transformer.CompileTransformation(&GroupsTransformation{Expression: `groups.map(g, "group_prefix:" + g)`}, nil)
require.NoError(t, err)
pipeline.AppendTransformation(compiledTransform)
compiledTransform, err = transformer.CompileTransformation(&AllowAuthenticationPolicy{Expression: `username == "username_prefix:ryan"`})
compiledTransform, err = transformer.CompileTransformation(&AllowAuthenticationPolicy{Expression: `username == "username_prefix:ryan"`}, nil)
require.NoError(t, err)
pipeline.AppendTransformation(compiledTransform)