Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cost tracking for two-variable comprehensions and bindings #1104

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
649 changes: 490 additions & 159 deletions checker/cost.go

Large diffs are not rendered by default.

91 changes: 76 additions & 15 deletions checker/cost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestCost(t *testing.T) {
nestedMap := types.NewMapType(types.StringType, allMap)

zeroCost := CostEstimate{}
oneCost := CostEstimate{Min: 1, Max: 1}
oneCost := FixedCostEstimate(1)
cases := []struct {
name string
expr string
Expand Down Expand Up @@ -255,6 +255,11 @@ func TestCost(t *testing.T) {
expr: `size("123")`,
wanted: oneCost,
},
{
name: "bytes size",
expr: `size(b"123")`,
wanted: oneCost,
},
{
name: "bytes to string conversion",
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
Expand Down Expand Up @@ -462,6 +467,36 @@ func TestCost(t *testing.T) {
},
wanted: CostEstimate{Min: 5, Max: 5},
},
{
name: "list size from concat",
expr: `([x, y] + list1 + list2).size()`,
vars: []*decls.VariableDecl{
decls.NewVariable("x", types.IntType),
decls.NewVariable("y", types.IntType),
decls.NewVariable("list1", types.NewListType(types.IntType)),
decls.NewVariable("list2", types.NewListType(types.IntType)),
},
hints: map[string]uint64{
"list1": 10,
"list2": 20,
},
wanted: CostEstimate{Min: 17, Max: 17},
},
{
name: "list cost tracking through comprehension",
expr: `[list1, list2].exists(l, l.exists(v, v.startsWith('hi')))`,
vars: []*decls.VariableDecl{
decls.NewVariable("list1", types.NewListType(types.StringType)),
decls.NewVariable("list2", types.NewListType(types.StringType)),
},
hints: map[string]uint64{
"list1": 10,
"list1.@items": 64,
"list2": 20,
"list2.@items": 128,
},
wanted: CostEstimate{Min: 21, Max: 265},
},
{
name: "str endsWith equality",
expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`,
Expand Down Expand Up @@ -539,27 +574,37 @@ func TestCost(t *testing.T) {
wanted: CostEstimate{Min: 61, Max: 61},
},
{
name: "nested array selection",
name: "nested map selection",
expr: `{'a': [1,2], 'b': [1,2], 'c': [1,2], 'd': [1,2], 'e': [1,2]}.b`,
wanted: CostEstimate{Min: 81, Max: 81},
},
{
// Estimated cost does not track the sizes of nested aggregate types
// (lists, maps, ...) and so assumes a worst case cost when an
// expression applies a comprehension to a nested aggregated type,
// even if the size information is available.
// TODO: This should be fixed.
name: "comprehension on nested list",
expr: `[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]].all(y, y.all(y, y == 1))`,
wanted: CostEstimate{Min: 76, Max: 136},
},
{
name: "comprehension on transformed nested list",
expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`,
wanted: CostEstimate{Min: 157, Max: 18446744073709551615},
wanted: CostEstimate{Min: 157, Max: 217},
},
{
// Make sure we're accounting for not just the iteration range size,
// but also the overall comprehension size. The chained map calls
// will treat the result of one map as the iteration range of the other,
// so they're planned in reverse; however, the `+` should verify that
// the comprehension result has a size.
name: "comprehension size",
name: "comprehension on nested literal list",
expr: `["a", "ab", "abc", "abcd", "abcde"].map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
wanted: CostEstimate{Min: 157, Max: 217},
},
{
name: "comprehension on nested variable list",
expr: `input.map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
vars: []*decls.VariableDecl{decls.NewVariable("input", types.NewListType(types.StringType))},
hints: map[string]uint64{
"input": 5,
"input.@items": 10,
},
wanted: CostEstimate{Min: 13, Max: 208},
},
{
name: "comprehension chaining with concat",
expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`,
wanted: CostEstimate{Min: 173, Max: 173},
},
Expand All @@ -568,9 +613,25 @@ func TestCost(t *testing.T) {
expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`,
wanted: CostEstimate{Min: 20, Max: 230},
},
{
name: "nested dyn comprehension",
expr: `dyn([1,2,3]).all(i, i in dyn([1,2,3]).map(j, j + j))`,
wanted: CostEstimate{Min: 21, Max: 234},
},
{
name: "literal map access",
expr: `{'hello': 'hi'}['hello'] != {'hello': 'bye'}['hello']`,
wanted: CostEstimate{Min: 63, Max: 63},
},
{
name: "literal list access",
expr: `['hello', 'hi'][0] != ['hello', 'bye'][1]`,
wanted: CostEstimate{Min: 23, Max: 23},
},
}

for _, tc := range cases {
for _, tst := range cases {
tc := tst
t.Run(tc.name, func(t *testing.T) {
if tc.hints == nil {
tc.hints = map[string]uint64{}
Expand Down
105 changes: 75 additions & 30 deletions ext/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,56 +20,101 @@ import (
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)

var bindingTests = []struct {
expr string
parseOnly bool
name string
expr string
vars []cel.EnvOption
in map[string]any
hints map[string]uint64
estimatedCost checker.CostEstimate
actualCost uint64
}{
{expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) ==
['hell' + 'o' + '!', 'hell' + 'o' + '!', 'hell' + 'o' + '!'].join(', ')`},
// Variable shadowing
{expr: `cel.bind(a,
cel.bind(a, 'world', a + '!'),
'hello ' + a) == 'hello ' + 'world' + '!'`},
{
name: "single bind",
expr: `cel.bind(a, 'hell' + 'o' + '!', "%s, %s, %s".format([a, a, a])) ==
'hello!, hello!, hello' + '!'`,
estimatedCost: checker.CostEstimate{Min: 30, Max: 32},
actualCost: 32,
},
{
name: "multiple binds",
expr: `cel.bind(a, 'hello!',
cel.bind(b, 'goodbye',
a + ' and, ' + b)) == 'hello! and, goodbye'`,
estimatedCost: checker.CostEstimate{Min: 27, Max: 28},
actualCost: 28,
},
{
name: "shadow binds",
expr: `cel.bind(a,
cel.bind(a, 'world', a + '!'),
'hello ' + a) == 'hello ' + 'world' + '!'`,
estimatedCost: checker.CostEstimate{Min: 30, Max: 31},
actualCost: 31,
},
{
name: "nested bind with int list",
expr: `cel.bind(a, x,
cel.bind(b, a[0],
cel.bind(c, a[1], b + c))) == 10`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))},
in: map[string]any{
"x": []int64{3, 7},
},
hints: map[string]uint64{
"x": 3,
},
estimatedCost: checker.CostEstimate{Min: 39, Max: 39},
actualCost: 39,
},
{
name: "nested bind with string list",
expr: `cel.bind(a, x,
cel.bind(b, a[0],
cel.bind(c, a[1], b + c))) == "threeseven"`,
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.StringType))},
in: map[string]any{
"x": []string{"three", "seven"},
},
hints: map[string]uint64{
"x": 3,
"x.@items": 10,
},
estimatedCost: checker.CostEstimate{Min: 38, Max: 40},
actualCost: 39,
},
}

func TestBindings(t *testing.T) {
env, err := cel.NewEnv(Bindings(BindingsVersion(0)), Strings())
if err != nil {
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
}
for i, tst := range bindingTests {
for _, tst := range bindingTests {
tc := tst
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var asts []*cel.Ast
opts := append([]cel.EnvOption{Bindings(BindingsVersion(0)), Strings()}, tc.vars...)
env, err := cel.NewEnv(opts...)
if err != nil {
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
}
pAst, iss := env.Parse(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, pAst)
if !tc.parseOnly {
cAst, iss := env.Check(pAst)
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}
asts = append(asts, cAst)
cAst, iss := env.Check(pAst)
if iss.Err() != nil {
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
}
testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost)
asts = append(asts, cAst)
for _, ast := range asts {
prg, err := env.Program(ast)
if err != nil {
t.Fatal(err)
}
out, _, err := prg.Eval(cel.NoVars())
if err != nil {
t.Fatal(err)
} else if out.Value() != true {
t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr)
}
testEvalWithCost(t, env, ast, tc.in, tc.actualCost)
}
})
}
Expand Down
Loading