diff --git a/internal/celtransformer/celformer.go b/internal/celtransformer/celformer.go index 409e064b..133555a1 100644 --- a/internal/celtransformer/celformer.go +++ b/internal/celtransformer/celformer.go @@ -12,6 +12,7 @@ import ( "context" "fmt" "reflect" + "regexp" "strings" "time" @@ -23,8 +24,10 @@ import ( ) const ( - usernameVariableName = "username" - groupsVariableName = "groups" + usernameVariableName = "username" + groupsVariableName = "groups" + constStringVariableName = "strConst" + constStringListVariableName = "strListConst" defaultPolicyRejectedAuthMessage = "Authentication was rejected by a configured policy" ) @@ -46,15 +49,51 @@ func NewCELTransformer(maxExpressionRuntime time.Duration) (*CELTransformer, err return &CELTransformer{compiler: env, maxExpressionRuntime: maxExpressionRuntime}, nil } +// TransformationConstants can be used to make more variables available to compiled CEL expressions for convenience. +type TransformationConstants struct { + // A map of variable names to their string values. If a key "x" has value "123", then it will be available + // to CEL expressions as the variable `strConst.x` with value `"123"`. + StringConstants map[string]string + // A map of variable names to their string list values. If a key "x" has value []string{"123","456"}, + // then it will be available to CEL expressions as the variable `strListConst.x` with value `["123","456"]`. + StringListConstants map[string][]string +} + +// Valid identifiers in CEL expressions are defined by the CEL language spec as: [_a-zA-Z][_a-zA-Z0-9]* +var validIdentifiersRegexp = regexp.MustCompile(`^[_a-zA-Z][_a-zA-Z0-9]*$`) + +func (t *TransformationConstants) validateVariableNames() error { + const errFormat = "%q is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)" + for k := range t.StringConstants { + if !validIdentifiersRegexp.MatchString(k) { + return fmt.Errorf(errFormat, k) + } + } + for k := range t.StringListConstants { + if !validIdentifiersRegexp.MatchString(k) { + return fmt.Errorf(errFormat, k) + } + } + return nil +} + // CompileTransformation compiles a CEL-based identity transformation expression. // The compiled transform can be cached in memory and executed repeatedly and in a thread-safe way. -func (c *CELTransformer) CompileTransformation(t CELTransformation) (idtransform.IdentityTransformation, error) { - return t.compile(c) +// The caller must not modify the consts param struct after calling this function to allow +// the returned IdentityTransformation to use it as a thread-safe read-only structure. +func (c *CELTransformer) CompileTransformation(t CELTransformation, consts *TransformationConstants) (idtransform.IdentityTransformation, error) { + if consts == nil { + consts = &TransformationConstants{} + } + if err := consts.validateVariableNames(); err != nil { + return nil, err + } + return t.compile(c, consts) } // CELTransformation can be compiled into an IdentityTransformation. type CELTransformation interface { - compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) + compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) } // UsernameTransformation is a CEL expression that can transform a username (or leave it unchanged). @@ -108,76 +147,91 @@ func compileProgram(transformer *CELTransformer, expectedExpressionType *cel.Typ return program, nil } -func (t *UsernameTransformation) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) { +func (t *UsernameTransformation) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) { program, err := compileProgram(transformer, cel.StringType, t.Expression) if err != nil { return nil, err } return &compiledUsernameTransformation{ - program: program, - maxExpressionRuntime: transformer.maxExpressionRuntime, + baseCompiledTransformation: &baseCompiledTransformation{ + program: program, + consts: consts, + maxExpressionRuntime: transformer.maxExpressionRuntime, + }, }, nil } -func (t *GroupsTransformation) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) { +func (t *GroupsTransformation) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) { program, err := compileProgram(transformer, cel.ListType(cel.StringType), t.Expression) if err != nil { return nil, err } return &compiledGroupsTransformation{ - program: program, - maxExpressionRuntime: transformer.maxExpressionRuntime, + baseCompiledTransformation: &baseCompiledTransformation{ + program: program, + consts: consts, + maxExpressionRuntime: transformer.maxExpressionRuntime, + }, }, nil } -func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer) (idtransform.IdentityTransformation, error) { +func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer, consts *TransformationConstants) (idtransform.IdentityTransformation, error) { program, err := compileProgram(transformer, cel.BoolType, t.Expression) if err != nil { return nil, err } return &compiledAllowAuthenticationPolicy{ - program: program, - maxExpressionRuntime: transformer.maxExpressionRuntime, + baseCompiledTransformation: &baseCompiledTransformation{ + program: program, + consts: consts, + maxExpressionRuntime: transformer.maxExpressionRuntime, + }, rejectedAuthenticationMessage: t.RejectedAuthenticationMessage, }, nil } +// Base type for common aspects of compiled transformations. +type baseCompiledTransformation struct { + program cel.Program + consts *TransformationConstants + maxExpressionRuntime time.Duration +} + // Implements idtransform.IdentityTransformation. type compiledUsernameTransformation struct { - program cel.Program - maxExpressionRuntime time.Duration + *baseCompiledTransformation } // Implements idtransform.IdentityTransformation. type compiledGroupsTransformation struct { - program cel.Program - maxExpressionRuntime time.Duration + *baseCompiledTransformation } // Implements idtransform.IdentityTransformation. type compiledAllowAuthenticationPolicy struct { - program cel.Program - maxExpressionRuntime time.Duration + *baseCompiledTransformation rejectedAuthenticationMessage string } -func evalProgram(ctx context.Context, program cel.Program, maxExpressionRuntime time.Duration, username string, groups []string) (ref.Val, error) { +func (c *baseCompiledTransformation) evalProgram(ctx context.Context, username string, groups []string) (ref.Val, error) { // Limit the runtime of a CEL expression to avoid accidental very expensive expressions. - timeoutCtx, cancel := context.WithTimeout(ctx, maxExpressionRuntime) + timeoutCtx, cancel := context.WithTimeout(ctx, c.maxExpressionRuntime) defer cancel() // Evaluation is thread-safe and side effect free. Many inputs can be sent to the same cel.Program // and if fields are present in the input, but not referenced in the expression, they are ignored. // The argument to Eval may either be an `interpreter.Activation` or a `map[string]interface{}`. - val, _, err := program.ContextEval(timeoutCtx, map[string]interface{}{ - usernameVariableName: username, - groupsVariableName: groups, + val, _, err := c.program.ContextEval(timeoutCtx, map[string]interface{}{ + usernameVariableName: username, + groupsVariableName: groups, + constStringVariableName: c.consts.StringConstants, + constStringListVariableName: c.consts.StringListConstants, }) return val, err } func (c *compiledUsernameTransformation) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) { - val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups) + val, err := c.evalProgram(ctx, username, groups) if err != nil { return nil, err } @@ -197,7 +251,7 @@ func (c *compiledUsernameTransformation) Evaluate(ctx context.Context, username } func (c *compiledGroupsTransformation) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) { - val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups) + val, err := c.evalProgram(ctx, username, groups) if err != nil { return nil, err } @@ -217,7 +271,7 @@ func (c *compiledGroupsTransformation) Evaluate(ctx context.Context, username st } func (c *compiledAllowAuthenticationPolicy) Evaluate(ctx context.Context, username string, groups []string) (*idtransform.TransformationResult, error) { - val, err := evalProgram(ctx, c.program, c.maxExpressionRuntime, username, groups) + val, err := c.evalProgram(ctx, username, groups) if err != nil { return nil, err } @@ -254,6 +308,8 @@ func newEnv() (*cel.Env, error) { // the parsing/checking phase. cel.Variable(usernameVariableName, cel.StringType), cel.Variable(groupsVariableName, cel.ListType(cel.StringType)), + cel.Variable(constStringVariableName, cel.MapType(cel.StringType, cel.StringType)), + cel.Variable(constStringListVariableName, cel.MapType(cel.StringType, cel.ListType(cel.StringType))), // Enable the strings extensions. // See https://github.com/google/cel-go/tree/master/ext#strings diff --git a/internal/celtransformer/celformer_test.go b/internal/celtransformer/celformer_test.go index 762f8ecf..54724a41 100644 --- a/internal/celtransformer/celformer_test.go +++ b/internal/celtransformer/celformer_test.go @@ -30,6 +30,7 @@ func TestTransformer(t *testing.T) { username string groups []string transforms []CELTransformation + consts *TransformationConstants ctx context.Context wantUsername string @@ -113,6 +114,28 @@ func TestTransformer(t *testing.T) { wantUsername: "other", wantGroups: []string{"admins", "developers", "other", "ryan", "other2"}, }, + { + name: "any transformation can use the provided constants as variables", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + consts: &TransformationConstants{ + StringConstants: map[string]string{ + "x": "abc", + "y": "def", + }, + StringListConstants: map[string][]string{ + "x": {"uvw", "xyz"}, + "y": {"123", "456"}, + }, + }, + transforms: []CELTransformation{ + &UsernameTransformation{Expression: `strConst.x + strListConst.x[0]`}, + &GroupsTransformation{Expression: `[strConst.x, strConst.y, strListConst.x[1], strListConst.y[0]]`}, + &AllowAuthenticationPolicy{Expression: `strConst.x == "abc"`}, + }, + wantUsername: "abcuvw", + wantGroups: []string{"abc", "def", "xyz", "123"}, + }, { name: "the CEL string extensions are enabled for use in the expressions", username: " ryan ", @@ -219,6 +242,19 @@ func TestTransformer(t *testing.T) { wantUsername: "ryan", wantGroups: []string{"admins", "developers"}, }, + { + name: "can filter groups based on an allow list provided as a const", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + consts: &TransformationConstants{ + StringListConstants: map[string][]string{"allowedGroups": {"admins", "developers"}}, + }, + transforms: []CELTransformation{ + &GroupsTransformation{Expression: `groups.filter(g, g in strListConst.allowedGroups)`}, + }, + wantUsername: "ryan", + wantGroups: []string{"admins", "developers"}, + }, { name: "can filter groups based on a disallow list", username: "ryan", @@ -239,6 +275,19 @@ func TestTransformer(t *testing.T) { wantUsername: "ryan", wantGroups: []string{"other"}, }, + { + name: "can filter groups based on a disallowed prefixes provided as a const", + username: "ryan", + groups: []string{"disallowed1:admins", "disallowed2:developers", "other"}, + consts: &TransformationConstants{ + StringListConstants: map[string][]string{"disallowedPrefixes": {"disallowed1:", "disallowed2:"}}, + }, + transforms: []CELTransformation{ + &GroupsTransformation{Expression: `groups.filter(group, !(strListConst.disallowedPrefixes.exists(prefix, group.startsWith(prefix))))`}, + }, + wantUsername: "ryan", + wantGroups: []string{"other"}, + }, { name: "can add a group", username: "ryan", @@ -249,6 +298,19 @@ func TestTransformer(t *testing.T) { wantUsername: "ryan", wantGroups: []string{"admins", "developers", "other", "new-group"}, }, + { + name: "can add a group from a const", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + consts: &TransformationConstants{ + StringConstants: map[string]string{"groupToAlwaysAdd": "new-group"}, + }, + transforms: []CELTransformation{ + &GroupsTransformation{Expression: `groups + [strConst.groupToAlwaysAdd]`}, + }, + wantUsername: "ryan", + wantGroups: []string{"admins", "developers", "other", "new-group"}, + }, { name: "can add a group but only if they already belong to another group - when the user does belong to that other group", username: "ryan", @@ -622,6 +684,44 @@ func TestTransformer(t *testing.T) { }, wantCompileErr: `CEL expression should return type "list(string)" but returns type "list(dyn)"`, }, + { + name: "using string constants which were not were provided", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + transforms: []CELTransformation{ + &UsernameTransformation{Expression: `strConst.x`}, + }, + wantEvaluationErr: `identity transformation at index 0: no such key: x`, + }, + { + name: "using string list constants which were not were provided", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + transforms: []CELTransformation{ + &GroupsTransformation{Expression: `strListConst.x`}, + }, + wantEvaluationErr: `identity transformation at index 0: no such key: x`, + }, + { + name: "using an illegal name for a string constant", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + consts: &TransformationConstants{StringConstants: map[string]string{" illegal": "a"}}, + transforms: []CELTransformation{ + &UsernameTransformation{Expression: `username`}, + }, + wantCompileErr: `" illegal" is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)`, + }, + { + name: "using an illegal name for a stringList constant", + username: "ryan", + groups: []string{"admins", "developers", "other"}, + consts: &TransformationConstants{StringListConstants: map[string][]string{" illegal": {"a"}}}, + transforms: []CELTransformation{ + &UsernameTransformation{Expression: `username`}, + }, + wantCompileErr: `" illegal" is an invalid const variable name (must match [_a-zA-Z][_a-zA-Z0-9]*)`, + }, } for _, tt := range tests { @@ -635,7 +735,7 @@ func TestTransformer(t *testing.T) { pipeline := idtransform.NewTransformationPipeline() for _, transform := range tt.transforms { - compiledTransform, err := transformer.CompileTransformation(transform) + compiledTransform, err := transformer.CompileTransformation(transform, tt.consts) if tt.wantCompileErr != "" { require.EqualError(t, err, tt.wantCompileErr) return // the rest of the test doesn't make sense when there was a compile error @@ -673,13 +773,13 @@ func TestTypicalPerformanceAndThreadSafety(t *testing.T) { pipeline := idtransform.NewTransformationPipeline() var compiledTransform idtransform.IdentityTransformation - compiledTransform, err = transformer.CompileTransformation(&UsernameTransformation{Expression: `"username_prefix:" + username`}) + compiledTransform, err = transformer.CompileTransformation(&UsernameTransformation{Expression: `"username_prefix:" + username`}, nil) require.NoError(t, err) pipeline.AppendTransformation(compiledTransform) - compiledTransform, err = transformer.CompileTransformation(&GroupsTransformation{Expression: `groups.map(g, "group_prefix:" + g)`}) + compiledTransform, err = transformer.CompileTransformation(&GroupsTransformation{Expression: `groups.map(g, "group_prefix:" + g)`}, nil) require.NoError(t, err) pipeline.AppendTransformation(compiledTransform) - compiledTransform, err = transformer.CompileTransformation(&AllowAuthenticationPolicy{Expression: `username == "username_prefix:ryan"`}) + compiledTransform, err = transformer.CompileTransformation(&AllowAuthenticationPolicy{Expression: `username == "username_prefix:ryan"`}, nil) require.NoError(t, err) pipeline.AppendTransformation(compiledTransform)