Skip to content

Commit

Permalink
Allow cel.bind to be lazily evaluated
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601512156
  • Loading branch information
l46kok authored and copybara-github committed Jan 25, 2024
1 parent b7823ba commit 90e9b2a
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 26 deletions.
9 changes: 9 additions & 0 deletions common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public abstract class CelOptions {

public abstract boolean enableCelValue();

public abstract boolean enableComprehensionLazyEval();

public abstract int comprehensionMaxIterations();

public abstract Builder toBuilder();
Expand Down Expand Up @@ -179,6 +181,7 @@ public static Builder newBuilder() {
.resolveTypeDependencies(true)
.enableUnknownTracking(false)
.enableCelValue(false)
.enableComprehensionLazyEval(false)
.comprehensionMaxIterations(-1);
}

Expand Down Expand Up @@ -452,6 +455,12 @@ public abstract static class Builder {
*/
public abstract Builder comprehensionMaxIterations(int value);

/**
* Enables certain comprehension expressions to be lazily evaluated where safe. Currently, this
* only works for cel.bind.
*/
public abstract Builder enableComprehensionLazyEval(boolean value);

public abstract CelOptions build();
}
}
1 change: 1 addition & 0 deletions extensions/src/test/java/dev/cel/extensions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ java_library(
"//common/resources/testdata/proto2:messages_extensions_proto2_java_proto",
"//common/resources/testdata/proto2:messages_proto2_java_proto",
"//common/resources/testdata/proto2:test_all_types_java_proto",
"//common/resources/testdata/proto3:test_all_types_java_proto",
"//common/types",
"//common/types:type_providers",
"//compiler",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelValidationException;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.compiler.CelCompiler;
import dev.cel.compiler.CelCompilerFactory;
import dev.cel.parser.CelStandardMacro;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntime.CelFunctionBinding;
import dev.cel.runtime.CelRuntimeFactory;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;

Expand All @@ -45,25 +52,46 @@ public final class CelBindingsExtensionsTest {

private static final CelRuntime RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build();

private enum BindingTestCase {
BOOL_LITERAL("cel.bind(t, true, t)"),
STRING_CONCAT("cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\""),
NESTED_BINDS("cel.bind(t1, true, cel.bind(t2, true, t1 && t2))"),
NESTED_BINDS_SPECIFIER_ONLY(
"cel.bind(x, cel.bind(x, \"a\", x + x), x + \":\" + x) == \"aa:aa\""),
NESTED_BINDS_SPECIFIER_AND_VALUE(
"cel.bind(x, cel.bind(x, \"a\", x + x), cel.bind(y, x + x, y + \":\" + y)) =="
+ " \"aaaa:aaaa\""),
BIND_WITH_EXISTS_TRUE(
"cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"),
BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))");

private final String source;

BindingTestCase(String source) {
this.source = source;
}
}

@Test
@TestParameters("{expr: 'cel.bind(t, true, t)', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\"',"
+ " expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(t1, true, cel.bind(t2, true, t1 && t2))', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(valid_elems, [1, 2, 3], [3, 4, 5]"
+ ".exists(e, e in valid_elems))', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))',"
+ " expectedResult: true}")
public void binding_success(String expr, boolean expectedResult) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst();
public void binding_success(@TestParameter BindingTestCase testCase) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst();
CelRuntime.Program program = RUNTIME.createProgram(ast);
Object evaluatedResult = program.eval();
boolean evaluatedResult = (boolean) program.eval();

assertThat(evaluatedResult).isTrue();
}

@Test
public void binding_lazyEval_success(@TestParameter BindingTestCase testCase) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst();
CelRuntime.Program program =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.build()
.createProgram(ast);
boolean evaluatedResult = (boolean) program.eval();

assertThat(evaluatedResult).isEqualTo(expectedResult);
assertThat(evaluatedResult).isTrue();
}

@Test
Expand Down Expand Up @@ -105,4 +133,113 @@ public void binding_throwsCompilationException(String expr) throws Exception {

assertThat(e).hasMessageThat().contains("cel.bind() variable name must be a simple identifier");
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_bindingVarNeverReferenced() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.setStandardMacros(CelStandardMacro.HAS)
.addMessageTypes(TestAllTypes.getDescriptor())
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler.compile("cel.bind(t, get_true(), has(msg.single_int64) ? t : false)").getAst();

boolean result =
(boolean)
celRuntime
.createProgram(ast)
.eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance()));

assertThat(result).isFalse();
assertThat(invocation.get()).isEqualTo(0);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_accuInitEvaluatedOnce() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler.compile("cel.bind(t, get_true(), t && t && t && t)").getAst();

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_withNestedBinds() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler
.compile("cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))")
.getAst();

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ private static CelBuilder newCelBuilder() {
.setContainer("dev.cel.testing.testdata.proto3")
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(
CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build())
CelOptions.current()
.enableTimestampEpoch(true)
.enableComprehensionLazyEval(true)
.populateMacroCalls(true)
.build())
.addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings())
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.addFunctionDeclarations(
Expand Down
77 changes: 68 additions & 9 deletions runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,27 @@ private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, Stri
return IntermediateResult.create(typeValue);
}

IntermediateResult cachedResult = frame.lookupLazilyEvaluatedResult(name).orElse(null);
if (cachedResult != null) {
return cachedResult;
}

IntermediateResult rawResult = frame.resolveSimpleName(name, expr.id());
Object value = rawResult.value();
boolean isLazyExpression = value instanceof LazyExpression;
if (isLazyExpression) {
value = evalInternal(frame, ((LazyExpression) value).celExpr).value();
}

// Value resolved from Binding, it could be Message, PartialMessage or unbound(null)
Object value = InterpreterUtil.strict(typeProvider.adapt(rawResult.value()));
return IntermediateResult.create(rawResult.attribute(), value);
value = InterpreterUtil.strict(typeProvider.adapt(value));
IntermediateResult result = IntermediateResult.create(rawResult.attribute(), value);

if (isLazyExpression) {
frame.cacheLazilyEvaluatedResult(name, result);
}

return result;
}

private IntermediateResult evalSelect(ExecutionFrame frame, CelExpr expr, CelSelect selectExpr)
Expand Down Expand Up @@ -404,7 +420,7 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall
return IntermediateResult.create(attr, unknowns.get());
}

Object[] argArray = Arrays.stream(argResults).map(v -> v.value()).toArray();
Object[] argArray = Arrays.stream(argResults).map(IntermediateResult::value).toArray();

return IntermediateResult.create(
attr,
Expand Down Expand Up @@ -754,11 +770,12 @@ private IntermediateResult evalStruct(
fields.put(entry.fieldKey(), value);
}

Optional<Object> unknowns = argChecker.maybeUnknowns();
if (unknowns.isPresent()) {
return IntermediateResult.create(unknowns.get());
}
return IntermediateResult.create(typeProvider.createMessage(reference.name(), fields));
return argChecker
.maybeUnknowns()
.map(IntermediateResult::create)
.orElseGet(
() ->
IntermediateResult.create(typeProvider.createMessage(reference.name(), fields)));
}

// Evaluates the expression and returns a value-or-throwable.
Expand Down Expand Up @@ -796,7 +813,12 @@ private IntermediateResult evalComprehension(
.setLocation(metadata, compre.iterRange().id())
.build();
}
IntermediateResult accuValue = evalNonstrictly(frame, compre.accuInit());
IntermediateResult accuValue;
if (celOptions.enableComprehensionLazyEval() && LazyExpression.isLazilyEvaluable(compre)) {
accuValue = IntermediateResult.create(new LazyExpression(compre.accuInit()));
} else {
accuValue = evalNonstrictly(frame, compre.accuInit());
}
int i = 0;
for (Object elem : iterRange) {
frame.incrementIterations();
Expand Down Expand Up @@ -831,11 +853,39 @@ private IntermediateResult evalComprehension(
}
}

/** Contains a CelExpr that is to be lazily evaluated. */
private static class LazyExpression {
private final CelExpr celExpr;

/**
* Checks whether the provided expression can be evaluated lazily then cached. For example, the
* accumulator initializer in `cel.bind` macro is a good candidate because it never needs to be
* updated after being evaluated once.
*/
private static boolean isLazilyEvaluable(CelComprehension comprehension) {
// For now, just handle cel.bind. cel.block will be a future addition.
return comprehension
.loopCondition()
.constantOrDefault()
.getKind()
.equals(CelConstant.Kind.BOOLEAN_VALUE)
&& !comprehension.loopCondition().constant().booleanValue()
&& comprehension.iterVar().equals("#unused")
&& comprehension.iterRange().exprKind().getKind().equals(ExprKind.Kind.CREATE_LIST)
&& comprehension.iterRange().createList().elements().isEmpty();
}

private LazyExpression(CelExpr celExpr) {
this.celExpr = celExpr;
}
}

/** This class tracks the state meaningful to a single evaluation pass. */
private static class ExecutionFrame {
private final CelEvaluationListener evaluationListener;
private final int maxIterations;
private final ArrayDeque<RuntimeUnknownResolver> resolvers;
private final Map<String, IntermediateResult> lazyEvalResultCache;
private RuntimeUnknownResolver currentResolver;
private int iterations;

Expand All @@ -848,6 +898,7 @@ private ExecutionFrame(
this.resolvers.add(resolver);
this.currentResolver = resolver;
this.maxIterations = maxIterations;
this.lazyEvalResultCache = new HashMap<>();
}

private CelEvaluationListener getEvaluationListener() {
Expand Down Expand Up @@ -878,6 +929,14 @@ private Optional<Object> resolveAttribute(CelAttribute attr) {
return currentResolver.resolveAttribute(attr);
}

private Optional<IntermediateResult> lookupLazilyEvaluatedResult(String name) {
return Optional.ofNullable(lazyEvalResultCache.get(name));
}

private void cacheLazilyEvaluatedResult(String name, IntermediateResult result) {
lazyEvalResultCache.put(name, result);
}

private void pushScope(ImmutableMap<String, IntermediateResult> scope) {
RuntimeUnknownResolver scopedResolver = currentResolver.withScope(scope);
currentResolver = scopedResolver;
Expand Down

0 comments on commit 90e9b2a

Please sign in to comment.