diff --git a/common/src/main/java/dev/cel/common/CelOptions.java b/common/src/main/java/dev/cel/common/CelOptions.java index eaa812e2..2032752a 100644 --- a/common/src/main/java/dev/cel/common/CelOptions.java +++ b/common/src/main/java/dev/cel/common/CelOptions.java @@ -89,6 +89,8 @@ public abstract class CelOptions { public abstract boolean enableCelValue(); + public abstract boolean enableComprehensionLazyEval(); + public abstract int comprehensionMaxIterations(); public abstract Builder toBuilder(); @@ -179,6 +181,7 @@ public static Builder newBuilder() { .resolveTypeDependencies(true) .enableUnknownTracking(false) .enableCelValue(false) + .enableComprehensionLazyEval(false) .comprehensionMaxIterations(-1); } @@ -452,6 +455,12 @@ public abstract static class Builder { */ public abstract Builder comprehensionMaxIterations(int value); + /** + * Enables certain comprehension expressions to be lazily evaluated where safe. Currently, this + * only works for cel.bind. + */ + public abstract Builder enableComprehensionLazyEval(boolean value); + public abstract CelOptions build(); } } diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index b5b40d07..c71df2d8 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -15,6 +15,7 @@ java_library( "//common/resources/testdata/proto2:messages_extensions_proto2_java_proto", "//common/resources/testdata/proto2:messages_proto2_java_proto", "//common/resources/testdata/proto2:test_all_types_java_proto", + "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types", "//common/types:type_providers", "//compiler", diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index 2b836dcc..65657fbd 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -17,20 +17,27 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; import dev.cel.compiler.CelCompiler; import dev.cel.compiler.CelCompilerFactory; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelRuntime; import dev.cel.runtime.CelRuntime.CelFunctionBinding; import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -45,25 +52,46 @@ public final class CelBindingsExtensionsTest { private static final CelRuntime RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build(); + private enum BindingTestCase { + BOOL_LITERAL("cel.bind(t, true, t)"), + STRING_CONCAT("cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\""), + NESTED_BINDS("cel.bind(t1, true, cel.bind(t2, true, t1 && t2))"), + NESTED_BINDS_SPECIFIER_ONLY( + "cel.bind(x, cel.bind(x, \"a\", x + x), x + \":\" + x) == \"aa:aa\""), + NESTED_BINDS_SPECIFIER_AND_VALUE( + "cel.bind(x, cel.bind(x, \"a\", x + x), cel.bind(y, x + x, y + \":\" + y)) ==" + + " \"aaaa:aaaa\""), + BIND_WITH_EXISTS_TRUE( + "cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"), + BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))"); + + private final String source; + + BindingTestCase(String source) { + this.source = source; + } + } + @Test - @TestParameters("{expr: 'cel.bind(t, true, t)', expectedResult: true}") - @TestParameters( - "{expr: 'cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\"'," - + " expectedResult: true}") - @TestParameters( - "{expr: 'cel.bind(t1, true, cel.bind(t2, true, t1 && t2))', expectedResult: true}") - @TestParameters( - "{expr: 'cel.bind(valid_elems, [1, 2, 3], [3, 4, 5]" - + ".exists(e, e in valid_elems))', expectedResult: true}") - @TestParameters( - "{expr: 'cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))'," - + " expectedResult: true}") - public void binding_success(String expr, boolean expectedResult) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst(); + public void binding_success(@TestParameter BindingTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst(); CelRuntime.Program program = RUNTIME.createProgram(ast); - Object evaluatedResult = program.eval(); + boolean evaluatedResult = (boolean) program.eval(); + + assertThat(evaluatedResult).isTrue(); + } + + @Test + public void binding_lazyEval_success(@TestParameter BindingTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst(); + CelRuntime.Program program = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(CelOptions.current().enableComprehensionLazyEval(true).build()) + .build() + .createProgram(ast); + boolean evaluatedResult = (boolean) program.eval(); - assertThat(evaluatedResult).isEqualTo(expectedResult); + assertThat(evaluatedResult).isTrue(); } @Test @@ -105,4 +133,113 @@ public void binding_throwsCompilationException(String expr) throws Exception { assertThat(e).hasMessageThat().contains("cel.bind() variable name must be a simple identifier"); } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyBinding_bindingVarNeverReferenced() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.HAS) + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .addLibraries(CelExtensions.bindings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .build(); + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .setOptions(CelOptions.current().enableComprehensionLazyEval(true).build()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + celCompiler.compile("cel.bind(t, get_true(), has(msg.single_int64) ? t : false)").getAst(); + + boolean result = + (boolean) + celRuntime + .createProgram(ast) + .eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); + + assertThat(result).isFalse(); + assertThat(invocation.get()).isEqualTo(0); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyBinding_accuInitEvaluatedOnce() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addLibraries(CelExtensions.bindings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .build(); + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(CelOptions.current().enableComprehensionLazyEval(true).build()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + celCompiler.compile("cel.bind(t, get_true(), t && t && t && t)").getAst(); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + assertThat(invocation.get()).isEqualTo(1); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyBinding_withNestedBinds() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addLibraries(CelExtensions.bindings()) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + .build(); + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .setOptions(CelOptions.current().enableComprehensionLazyEval(true).build()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + celCompiler + .compile("cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))") + .getAst(); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + assertThat(invocation.get()).isEqualTo(2); + } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 9b516bdd..3ff4a249 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -86,7 +86,11 @@ private static CelBuilder newCelBuilder() { .setContainer("dev.cel.testing.testdata.proto3") .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .setOptions( - CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build()) + CelOptions.current() + .enableTimestampEpoch(true) + .enableComprehensionLazyEval(true) + .populateMacroCalls(true) + .build()) .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) .addFunctionDeclarations( diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index 44832a2e..b0f18d87 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -275,11 +275,27 @@ private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, Stri return IntermediateResult.create(typeValue); } + IntermediateResult cachedResult = frame.lookupLazilyEvaluatedResult(name).orElse(null); + if (cachedResult != null) { + return cachedResult; + } + IntermediateResult rawResult = frame.resolveSimpleName(name, expr.id()); + Object value = rawResult.value(); + boolean isLazyExpression = value instanceof LazyExpression; + if (isLazyExpression) { + value = evalInternal(frame, ((LazyExpression) value).celExpr).value(); + } // Value resolved from Binding, it could be Message, PartialMessage or unbound(null) - Object value = InterpreterUtil.strict(typeProvider.adapt(rawResult.value())); - return IntermediateResult.create(rawResult.attribute(), value); + value = InterpreterUtil.strict(typeProvider.adapt(value)); + IntermediateResult result = IntermediateResult.create(rawResult.attribute(), value); + + if (isLazyExpression) { + frame.cacheLazilyEvaluatedResult(name, result); + } + + return result; } private IntermediateResult evalSelect(ExecutionFrame frame, CelExpr expr, CelSelect selectExpr) @@ -404,7 +420,7 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall return IntermediateResult.create(attr, unknowns.get()); } - Object[] argArray = Arrays.stream(argResults).map(v -> v.value()).toArray(); + Object[] argArray = Arrays.stream(argResults).map(IntermediateResult::value).toArray(); return IntermediateResult.create( attr, @@ -754,11 +770,12 @@ private IntermediateResult evalStruct( fields.put(entry.fieldKey(), value); } - Optional unknowns = argChecker.maybeUnknowns(); - if (unknowns.isPresent()) { - return IntermediateResult.create(unknowns.get()); - } - return IntermediateResult.create(typeProvider.createMessage(reference.name(), fields)); + return argChecker + .maybeUnknowns() + .map(IntermediateResult::create) + .orElseGet( + () -> + IntermediateResult.create(typeProvider.createMessage(reference.name(), fields))); } // Evaluates the expression and returns a value-or-throwable. @@ -796,7 +813,12 @@ private IntermediateResult evalComprehension( .setLocation(metadata, compre.iterRange().id()) .build(); } - IntermediateResult accuValue = evalNonstrictly(frame, compre.accuInit()); + IntermediateResult accuValue; + if (celOptions.enableComprehensionLazyEval() && LazyExpression.isLazilyEvaluable(compre)) { + accuValue = IntermediateResult.create(new LazyExpression(compre.accuInit())); + } else { + accuValue = evalNonstrictly(frame, compre.accuInit()); + } int i = 0; for (Object elem : iterRange) { frame.incrementIterations(); @@ -831,11 +853,39 @@ private IntermediateResult evalComprehension( } } + /** Contains a CelExpr that is to be lazily evaluated. */ + private static class LazyExpression { + private final CelExpr celExpr; + + /** + * Checks whether the provided expression can be evaluated lazily then cached. For example, the + * accumulator initializer in `cel.bind` macro is a good candidate because it never needs to be + * updated after being evaluated once. + */ + private static boolean isLazilyEvaluable(CelComprehension comprehension) { + // For now, just handle cel.bind. cel.block will be a future addition. + return comprehension + .loopCondition() + .constantOrDefault() + .getKind() + .equals(CelConstant.Kind.BOOLEAN_VALUE) + && !comprehension.loopCondition().constant().booleanValue() + && comprehension.iterVar().equals("#unused") + && comprehension.iterRange().exprKind().getKind().equals(ExprKind.Kind.CREATE_LIST) + && comprehension.iterRange().createList().elements().isEmpty(); + } + + private LazyExpression(CelExpr celExpr) { + this.celExpr = celExpr; + } + } + /** This class tracks the state meaningful to a single evaluation pass. */ private static class ExecutionFrame { private final CelEvaluationListener evaluationListener; private final int maxIterations; private final ArrayDeque resolvers; + private final Map lazyEvalResultCache; private RuntimeUnknownResolver currentResolver; private int iterations; @@ -848,6 +898,7 @@ private ExecutionFrame( this.resolvers.add(resolver); this.currentResolver = resolver; this.maxIterations = maxIterations; + this.lazyEvalResultCache = new HashMap<>(); } private CelEvaluationListener getEvaluationListener() { @@ -878,6 +929,14 @@ private Optional resolveAttribute(CelAttribute attr) { return currentResolver.resolveAttribute(attr); } + private Optional lookupLazilyEvaluatedResult(String name) { + return Optional.ofNullable(lazyEvalResultCache.get(name)); + } + + private void cacheLazilyEvaluatedResult(String name, IntermediateResult result) { + lazyEvalResultCache.put(name, result); + } + private void pushScope(ImmutableMap scope) { RuntimeUnknownResolver scopedResolver = currentResolver.withScope(scope); currentResolver = scopedResolver;