Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow cel.bind to be lazily evaluated #221

Merged
1 commit merged into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions common/src/main/java/dev/cel/common/CelOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ public abstract class CelOptions {

public abstract boolean enableCelValue();

public abstract boolean enableComprehensionLazyEval();

public abstract int comprehensionMaxIterations();

public abstract Builder toBuilder();
Expand Down Expand Up @@ -179,6 +181,7 @@ public static Builder newBuilder() {
.resolveTypeDependencies(true)
.enableUnknownTracking(false)
.enableCelValue(false)
.enableComprehensionLazyEval(false)
.comprehensionMaxIterations(-1);
}

Expand Down Expand Up @@ -452,6 +455,12 @@ public abstract static class Builder {
*/
public abstract Builder comprehensionMaxIterations(int value);

/**
* Enables certain comprehension expressions to be lazily evaluated where safe. Currently, this
* only works for cel.bind.
*/
public abstract Builder enableComprehensionLazyEval(boolean value);

public abstract CelOptions build();
}
}
1 change: 1 addition & 0 deletions extensions/src/test/java/dev/cel/extensions/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ java_library(
"//common/resources/testdata/proto2:messages_extensions_proto2_java_proto",
"//common/resources/testdata/proto2:messages_proto2_java_proto",
"//common/resources/testdata/proto2:test_all_types_java_proto",
"//common/resources/testdata/proto3:test_all_types_java_proto",
"//common/types",
"//common/types:type_providers",
"//compiler",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelValidationException;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.compiler.CelCompiler;
import dev.cel.compiler.CelCompilerFactory;
import dev.cel.parser.CelStandardMacro;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntime.CelFunctionBinding;
import dev.cel.runtime.CelRuntimeFactory;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;

Expand All @@ -45,25 +52,46 @@ public final class CelBindingsExtensionsTest {

private static final CelRuntime RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build();

private enum BindingTestCase {
BOOL_LITERAL("cel.bind(t, true, t)"),
STRING_CONCAT("cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\""),
NESTED_BINDS("cel.bind(t1, true, cel.bind(t2, true, t1 && t2))"),
NESTED_BINDS_SPECIFIER_ONLY(
"cel.bind(x, cel.bind(x, \"a\", x + x), x + \":\" + x) == \"aa:aa\""),
NESTED_BINDS_SPECIFIER_AND_VALUE(
"cel.bind(x, cel.bind(x, \"a\", x + x), cel.bind(y, x + x, y + \":\" + y)) =="
+ " \"aaaa:aaaa\""),
BIND_WITH_EXISTS_TRUE(
"cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"),
BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))");

private final String source;

BindingTestCase(String source) {
this.source = source;
}
}

@Test
@TestParameters("{expr: 'cel.bind(t, true, t)', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(msg, \"hello\", msg + msg + msg) == \"hellohellohello\"',"
+ " expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(t1, true, cel.bind(t2, true, t1 && t2))', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(valid_elems, [1, 2, 3], [3, 4, 5]"
+ ".exists(e, e in valid_elems))', expectedResult: true}")
@TestParameters(
"{expr: 'cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))',"
+ " expectedResult: true}")
public void binding_success(String expr, boolean expectedResult) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(expr).getAst();
public void binding_success(@TestParameter BindingTestCase testCase) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst();
CelRuntime.Program program = RUNTIME.createProgram(ast);
Object evaluatedResult = program.eval();
boolean evaluatedResult = (boolean) program.eval();

assertThat(evaluatedResult).isTrue();
}

@Test
public void binding_lazyEval_success(@TestParameter BindingTestCase testCase) throws Exception {
CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst();
CelRuntime.Program program =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.build()
.createProgram(ast);
boolean evaluatedResult = (boolean) program.eval();

assertThat(evaluatedResult).isEqualTo(expectedResult);
assertThat(evaluatedResult).isTrue();
}

@Test
Expand Down Expand Up @@ -105,4 +133,113 @@ public void binding_throwsCompilationException(String expr) throws Exception {

assertThat(e).hasMessageThat().contains("cel.bind() variable name must be a simple identifier");
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_bindingVarNeverReferenced() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.setStandardMacros(CelStandardMacro.HAS)
.addMessageTypes(TestAllTypes.getDescriptor())
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler.compile("cel.bind(t, get_true(), has(msg.single_int64) ? t : false)").getAst();

boolean result =
(boolean)
celRuntime
.createProgram(ast)
.eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance()));

assertThat(result).isFalse();
assertThat(invocation.get()).isEqualTo(0);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_accuInitEvaluatedOnce() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler.compile("cel.bind(t, get_true(), t && t && t && t)").getAst();

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyBinding_withNestedBinds() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
.build();
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.setOptions(CelOptions.current().enableComprehensionLazyEval(true).build())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
celCompiler
.compile("cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))")
.getAst();

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ private static CelBuilder newCelBuilder() {
.setContainer("dev.cel.testing.testdata.proto3")
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(
CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build())
CelOptions.current()
.enableTimestampEpoch(true)
.enableComprehensionLazyEval(true)
.populateMacroCalls(true)
.build())
.addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings())
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.addFunctionDeclarations(
Expand Down
77 changes: 68 additions & 9 deletions runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,27 @@ private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, Stri
return IntermediateResult.create(typeValue);
}

IntermediateResult cachedResult = frame.lookupLazilyEvaluatedResult(name).orElse(null);
if (cachedResult != null) {
return cachedResult;
}

IntermediateResult rawResult = frame.resolveSimpleName(name, expr.id());
Object value = rawResult.value();
boolean isLazyExpression = value instanceof LazyExpression;
if (isLazyExpression) {
value = evalInternal(frame, ((LazyExpression) value).celExpr).value();
}

// Value resolved from Binding, it could be Message, PartialMessage or unbound(null)
Object value = InterpreterUtil.strict(typeProvider.adapt(rawResult.value()));
return IntermediateResult.create(rawResult.attribute(), value);
value = InterpreterUtil.strict(typeProvider.adapt(value));
IntermediateResult result = IntermediateResult.create(rawResult.attribute(), value);

if (isLazyExpression) {
frame.cacheLazilyEvaluatedResult(name, result);
}

return result;
}

private IntermediateResult evalSelect(ExecutionFrame frame, CelExpr expr, CelSelect selectExpr)
Expand Down Expand Up @@ -404,7 +420,7 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall
return IntermediateResult.create(attr, unknowns.get());
}

Object[] argArray = Arrays.stream(argResults).map(v -> v.value()).toArray();
Object[] argArray = Arrays.stream(argResults).map(IntermediateResult::value).toArray();

return IntermediateResult.create(
attr,
Expand Down Expand Up @@ -754,11 +770,12 @@ private IntermediateResult evalStruct(
fields.put(entry.fieldKey(), value);
}

Optional<Object> unknowns = argChecker.maybeUnknowns();
if (unknowns.isPresent()) {
return IntermediateResult.create(unknowns.get());
}
return IntermediateResult.create(typeProvider.createMessage(reference.name(), fields));
return argChecker
.maybeUnknowns()
.map(IntermediateResult::create)
.orElseGet(
() ->
IntermediateResult.create(typeProvider.createMessage(reference.name(), fields)));
}

// Evaluates the expression and returns a value-or-throwable.
Expand Down Expand Up @@ -796,7 +813,12 @@ private IntermediateResult evalComprehension(
.setLocation(metadata, compre.iterRange().id())
.build();
}
IntermediateResult accuValue = evalNonstrictly(frame, compre.accuInit());
IntermediateResult accuValue;
if (celOptions.enableComprehensionLazyEval() && LazyExpression.isLazilyEvaluable(compre)) {
accuValue = IntermediateResult.create(new LazyExpression(compre.accuInit()));
} else {
accuValue = evalNonstrictly(frame, compre.accuInit());
}
int i = 0;
for (Object elem : iterRange) {
frame.incrementIterations();
Expand Down Expand Up @@ -831,11 +853,39 @@ private IntermediateResult evalComprehension(
}
}

/** Contains a CelExpr that is to be lazily evaluated. */
private static class LazyExpression {
private final CelExpr celExpr;

/**
* Checks whether the provided expression can be evaluated lazily then cached. For example, the
* accumulator initializer in `cel.bind` macro is a good candidate because it never needs to be
* updated after being evaluated once.
*/
private static boolean isLazilyEvaluable(CelComprehension comprehension) {
// For now, just handle cel.bind. cel.block will be a future addition.
return comprehension
.loopCondition()
.constantOrDefault()
.getKind()
.equals(CelConstant.Kind.BOOLEAN_VALUE)
&& !comprehension.loopCondition().constant().booleanValue()
&& comprehension.iterVar().equals("#unused")
&& comprehension.iterRange().exprKind().getKind().equals(ExprKind.Kind.CREATE_LIST)
&& comprehension.iterRange().createList().elements().isEmpty();
}

private LazyExpression(CelExpr celExpr) {
this.celExpr = celExpr;
}
}

/** This class tracks the state meaningful to a single evaluation pass. */
private static class ExecutionFrame {
private final CelEvaluationListener evaluationListener;
private final int maxIterations;
private final ArrayDeque<RuntimeUnknownResolver> resolvers;
private final Map<String, IntermediateResult> lazyEvalResultCache;
private RuntimeUnknownResolver currentResolver;
private int iterations;

Expand All @@ -848,6 +898,7 @@ private ExecutionFrame(
this.resolvers.add(resolver);
this.currentResolver = resolver;
this.maxIterations = maxIterations;
this.lazyEvalResultCache = new HashMap<>();
}

private CelEvaluationListener getEvaluationListener() {
Expand Down Expand Up @@ -878,6 +929,14 @@ private Optional<Object> resolveAttribute(CelAttribute attr) {
return currentResolver.resolveAttribute(attr);
}

private Optional<IntermediateResult> lookupLazilyEvaluatedResult(String name) {
return Optional.ofNullable(lazyEvalResultCache.get(name));
}

private void cacheLazilyEvaluatedResult(String name, IntermediateResult result) {
lazyEvalResultCache.put(name, result);
}

private void pushScope(ImmutableMap<String, IntermediateResult> scope) {
RuntimeUnknownResolver scopedResolver = currentResolver.withScope(scope);
currentResolver = scopedResolver;
Expand Down
Loading