Skip to content
Merged
209 changes: 209 additions & 0 deletions policy/compiler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package policy provides an extensible parser and compiler for composing
// a graph of CEL expressions into a single evaluable expression.
package policy

import (
"fmt"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
)

type compiler struct {
env *cel.Env
info *ast.SourceInfo
src *Source
}

type compiledRule struct {
variables []*compiledVariable
matches []*compiledMatch
}

type compiledVariable struct {
name string
expr *cel.Ast
}

type compiledMatch struct {
cond *cel.Ast
output *cel.Ast
nestedRule *compiledRule
}

// Compile generates a single CEL AST from a collection of policy expressions associated with a CEL environment.
func Compile(env *cel.Env, p *Policy) (*cel.Ast, *cel.Issues) {
c := &compiler{
env: env,
info: p.SourceInfo(),
src: p.Source(),
}
errs := common.NewErrors(c.src)
iss := cel.NewIssuesWithSourceInfo(errs, c.info)
rule, ruleIss := c.compileRule(p.Rule(), c.env, iss)
iss = iss.Append(ruleIss)
if iss.Err() != nil {
return nil, iss
}
ruleRoot, _ := env.Compile("true")
opt := cel.NewStaticOptimizer(&ruleComposer{rule: rule})
ruleExprAST, optIss := opt.Optimize(env, ruleRoot)
return ruleExprAST, iss.Append(optIss)
}

func (c *compiler) compileRule(r *Rule, ruleEnv *cel.Env, iss *cel.Issues) (*compiledRule, *cel.Issues) {
var err error
compiledVars := make([]*compiledVariable, len(r.Variables()))
for i, v := range r.Variables() {
exprSrc := c.relSource(v.Expression())
varAST, exprIss := ruleEnv.CompileSource(exprSrc)
if exprIss.Err() == nil {
ruleEnv, err = ruleEnv.Extend(cel.Variable(fmt.Sprintf("%s.%s", variablePrefix, v.Name().Value), varAST.OutputType()))
if err != nil {
iss.ReportErrorAtID(v.Expression().ID, "invalid variable declaration")
}
compiledVars[i] = &compiledVariable{
name: v.name.Value,
expr: varAST,
}
}
iss = iss.Append(exprIss)
}
compiledMatches := []*compiledMatch{}
for _, m := range r.Matches() {
condSrc := c.relSource(m.Condition())
condAST, condIss := ruleEnv.CompileSource(condSrc)
iss = iss.Append(condIss)
// This case cannot happen when the Policy object is parsed from yaml, but could happen
// with a non-YAML generation of the Policy object.
// TODO: Test this case once there's an alternative method of constructing Policy objects
if m.HasOutput() && m.HasRule() {
iss.ReportErrorAtID(m.Condition().ID, "either output or rule may be set but not both")
continue
}
if m.HasOutput() {
outSrc := c.relSource(m.Output())
outAST, outIss := ruleEnv.CompileSource(outSrc)
iss = iss.Append(outIss)
compiledMatches = append(compiledMatches, &compiledMatch{
cond: condAST,
output: outAST,
})
continue
}
if m.HasRule() {
nestedRule, ruleIss := c.compileRule(m.Rule(), ruleEnv, iss)
iss = iss.Append(ruleIss)
compiledMatches = append(compiledMatches, &compiledMatch{
cond: condAST,
nestedRule: nestedRule,
})
}
}
return &compiledRule{
variables: compiledVars,
matches: compiledMatches,
}, iss
}

func (c *compiler) relSource(pstr ValueString) *RelativeSource {
line := 0
col := 1
if offset, found := c.info.GetOffsetRange(pstr.ID); found {
if loc, found := c.src.OffsetLocation(offset.Start); found {
line = loc.Line()
col = loc.Column()
}
}
return c.src.Relative(pstr.Value, line, col)
}

type ruleComposer struct {
rule *compiledRule
}

// Optimize implements an AST optimizer for CEL which composes an expression graph into a single
// expression value.
func (opt *ruleComposer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
// The input to optimize is a dummy expression which is completely replaced according
// to the configuration of the rule composition graph.
ruleExpr, _ := optimizeRule(ctx, opt.rule)
return ctx.NewAST(ruleExpr)
}

func optimizeRule(ctx *cel.OptimizerContext, r *compiledRule) (ast.Expr, bool) {
matchExpr := ctx.NewCall("optional.none")
matches := r.matches
optionalResult := true
for i := len(matches) - 1; i >= 0; i-- {
m := matches[i]
cond := ctx.CopyASTAndMetadata(m.cond.NativeRep())
triviallyTrue := cond.Kind() == ast.LiteralKind && cond.AsLiteral() == types.True
if m.output != nil {
out := ctx.CopyASTAndMetadata(m.output.NativeRep())
if triviallyTrue {
matchExpr = out
optionalResult = false
continue
}
if optionalResult {
out = ctx.NewCall("optional.of", out)
}
matchExpr = ctx.NewCall(
operators.Conditional,
cond,
out,
matchExpr)
continue
}
nestedRule, nestedOptional := optimizeRule(ctx, m.nestedRule)
if optionalResult && !nestedOptional {
nestedRule = ctx.NewCall("optional.of", nestedRule)
}
if !optionalResult && nestedOptional {
matchExpr = ctx.NewCall("optional.of", matchExpr)
optionalResult = true
}
if !optionalResult && !nestedOptional {
ctx.ReportErrorAtID(nestedRule.ID(), "subrule early terminates policy")
continue
}
matchExpr = ctx.NewMemberCall("or", nestedRule, matchExpr)
}

vars := r.variables
for i := len(vars) - 1; i >= 0; i-- {
v := vars[i]
varAST := ctx.CopyASTAndMetadata(v.expr.NativeRep())
// Build up the bindings in reverse order, starting from root, all the way up to the outermost
// binding:
// currExpr = cel.bind(outerVar, outerExpr, currExpr)
varName := fmt.Sprintf("%s.%s", variablePrefix, v.name)
inlined, bindMacro := ctx.NewBindMacro(matchExpr.ID(), varName, varAST, matchExpr)
ctx.UpdateExpr(matchExpr, inlined)
ctx.SetMacroCall(matchExpr.ID(), bindMacro)
}
return matchExpr, optionalResult
}

const (
// Consider making the variables namespace configurable.
variablePrefix = "variables"
)
168 changes: 168 additions & 0 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package policy

import (
"fmt"
"strings"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
)

func TestCompile(t *testing.T) {
for _, tst := range policyTests {
r := newRunner(t, tst.name, tst.expr, tst.envOpts...)
r.run(t)
}
}

func BenchmarkCompile(b *testing.B) {
for _, tst := range policyTests {
r := newRunner(b, tst.name, tst.expr, tst.envOpts...)
r.bench(b)
}
}

func newRunner(t testing.TB, name, expr string, opts ...cel.EnvOption) *runner {
r := &runner{name: name, envOptions: opts, expr: expr}
r.setup(t)
return r
}

type runner struct {
name string
envOptions []cel.EnvOption
env *cel.Env
expr string
prg cel.Program
}

func (r *runner) setup(t testing.TB) {
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name))
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", r.name))
parser, err := NewParser()
if err != nil {
t.Fatalf("NewParser() failed: %v", err)
}
policy, iss := parser.Parse(srcFile)
if iss.Err() != nil {
t.Fatalf("Parse() failed: %v", iss.Err())
}
if policy.name.Value != r.name {
t.Errorf("policy name is %v, wanted %s", policy.name, r.name)
}
env, err := cel.NewEnv(
cel.OptionalTypes(),
cel.EnableMacroCallTracking(),
cel.ExtendedValidations())
if err != nil {
t.Fatalf("cel.NewEnv() failed: %v", err)
}
// Configure declarations
configOpts, err := config.AsEnvOptions(env)
if err != nil {
t.Fatalf("config.AsEnvOptions() failed: %v", err)
}
env, err = env.Extend(configOpts...)
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
// Configure any implementations
env, err = env.Extend(r.envOptions...)
if err != nil {
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
}
ast, iss := Compile(env, policy)
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
pExpr, err := cel.AstToString(ast)
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
if r.expr != "" && normalize(pExpr) != normalize(r.expr) {
t.Errorf("cel.AstToString() got %s, wanted %s", pExpr, r.expr)
}
prg, err := env.Program(ast, cel.EvalOptions(cel.OptOptimize))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
r.env = env
r.prg = prg
}

func (r *runner) run(t *testing.T) {
tests := readTestSuite(t, fmt.Sprintf("testdata/%s/tests.yaml", r.name))
for _, s := range tests.Sections {
section := s.Name
for _, tst := range s.Tests {
tc := tst
t.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(t *testing.T) {
out, _, err := r.prg.Eval(tc.Input)
if err != nil {
t.Fatalf("prg.Eval(tc.Input) failed: %v", err)
}
wantExpr, iss := r.env.Compile(tc.Output)
if iss.Err() != nil {
t.Fatalf("env.Compile(%q) failed :%v", tc.Output, iss.Err())
}
testPrg, err := r.env.Program(wantExpr)
if err != nil {
t.Fatalf("env.Program(wantExpr) failed: %v", err)
}
testOut, _, err := testPrg.Eval(cel.NoVars())
if err != nil {
t.Fatalf("testPrg.Eval() failed: %v", err)
}
if optOut, ok := out.(*types.Optional); ok {
if optOut.Equal(types.OptionalNone) == types.True {
if testOut.Equal(types.OptionalNone) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
} else if testOut.Equal(optOut.GetValue()) != types.True {
t.Errorf("policy eval got %v, wanted %v", out, testOut)
}
}
})
}
}
}

func (r *runner) bench(b *testing.B) {
tests := readTestSuite(b, fmt.Sprintf("testdata/%s/tests.yaml", r.name))
for _, s := range tests.Sections {
section := s.Name
for _, tst := range s.Tests {
tc := tst
b.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, err := r.prg.Eval(tc.Input)
if err != nil {
b.Fatalf("policy eval failed: %v", err)
}
}
})
}
}
}

func normalize(s string) string {
return strings.ReplaceAll(
strings.ReplaceAll(
strings.ReplaceAll(s, " ", ""), "\n", ""),
"\t", "")
}
Loading