Skip to content

Commit

Permalink
Add capability to evaluate cel.block calls in the runtime
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606316063
  • Loading branch information
l46kok authored and copybara-github committed Feb 12, 2024
1 parent 00d7726 commit c916a11
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ java_library(
"//bundle:cel",
"//checker:checker_legacy_environment",
"//common",
"//common:compiler_common",
"//common/ast",
"//common/navigation",
"//common/types",
"//common/types:type_providers",
"//optimizer:ast_optimizer",
"//optimizer:mutable_ast",
"//parser:operator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,25 @@
import static java.util.Arrays.stream;

import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import dev.cel.bundle.Cel;
import dev.cel.checker.Standard;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelSource;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder;
import dev.cel.common.types.CelType;
import dev.cel.common.types.ListType;
import dev.cel.common.types.SimpleType;
import dev.cel.optimizer.CelAstOptimizer;
import dev.cel.optimizer.MutableAst;
import dev.cel.parser.Operator;
Expand Down Expand Up @@ -64,6 +70,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
private static final ImmutableSet<String> CSE_ALLOWED_FUNCTIONS =
Streams.concat(
stream(Operator.values()).map(Operator::getFunction),
Expand Down Expand Up @@ -325,6 +332,14 @@ private CelExpr normalizeForEquality(CelExpr celExpr) {
return mutableAst.clearExprIds(celExpr);
}

@VisibleForTesting
static CelFunctionDecl newCelBlockFunctionDecl(CelType resultType) {
return CelFunctionDecl.newFunctionDeclaration(
CEL_BLOCK_FUNCTION,
CelOverloadDecl.newGlobalOverload(
"cel_block_list", resultType, ListType.create(SimpleType.DYN), resultType));
}

/** Options to configure how Common Subexpression Elimination behave. */
@AutoValue
public abstract static class SubexpressionOptimizerOptions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@ java_library(
"//common:compiler_common",
"//common:options",
"//common/ast",
"//common/navigation",
"//common/resources/testdata/proto3:test_all_types_java_proto",
"//common/types",
"//extensions",
"//extensions:optional_library",
"//optimizer",
"//optimizer:mutable_ast",
"//optimizer:optimization_exception",
"//optimizer:optimizer_builder",
"//optimizer/optimizers:common_subexpression_elimination",
"//optimizer/optimizers:constant_folding",
"//parser:macro",
"//parser:operator",
"//parser:unparser",
"//runtime",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"@maven//:junit_junit",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,15 @@
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.types.ListType;
import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
Expand All @@ -39,14 +46,19 @@
import dev.cel.optimizer.CelOptimizationException;
import dev.cel.optimizer.CelOptimizer;
import dev.cel.optimizer.CelOptimizerFactory;
import dev.cel.optimizer.MutableAst;
import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions;
import dev.cel.parser.CelStandardMacro;
import dev.cel.parser.CelUnparser;
import dev.cel.parser.CelUnparserFactory;
import dev.cel.parser.Operator;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntime.CelFunctionBinding;
import dev.cel.runtime.CelRuntimeFactory;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.NestedTestAllTypes;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;

Expand All @@ -55,6 +67,37 @@ public class SubexpressionOptimizerTest {

private static final Cel CEL = newCelBuilder().build();

private static final Cel CEL_FOR_EVALUATING_BLOCK =
CelFactory.standardCelBuilder()
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.addFunctionDeclarations(
// These are test only declarations, as the actual function is made internal using @
// symbol.
// If the main function declaration needs updating, be sure to update the test
// declaration as well.
CelFunctionDecl.newFunctionDeclaration(
"cel.block",
CelOverloadDecl.newGlobalOverload(
"block_test_only_overload",
SimpleType.DYN,
ListType.create(SimpleType.DYN),
SimpleType.DYN)),
SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN),
CelFunctionDecl.newFunctionDeclaration(
"get_true",
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
// Similarly, this is a test only decl (index0 -> @index0)
.addVarDeclarations(
CelVarDecl.newVarDeclaration("index0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index2", SimpleType.DYN),
CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN))
.addMessageTypes(TestAllTypes.getDescriptor())
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
.build();

private static final CelOptimizer CEL_OPTIMIZER =
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
.addAstOptimizers(
Expand Down Expand Up @@ -659,4 +702,213 @@ public void iterationLimitReached_throws() throws Exception {
.optimize(ast));
assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached.");
}

private enum BlockTestCase {
BOOL_LITERAL("cel.block([true, false], index0 || index1)"),
STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd') == 'abcd'"),

BLOCK_WITH_EXISTS_TRUE("cel.block([[1, 2, 3], [3, 4, 5].exists(e, e in index0)], index1)"),
BLOCK_WITH_EXISTS_FALSE("cel.block([[1, 2, 3], ![4, 5].exists(e, e in index0)], index1)"),
;

private final String source;

BlockTestCase(String source) {
this.source = source;
}
}

@Test
public void block_success(@TestParameter BlockTestCase testCase) throws Exception {
CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source);

Object evaluatedResult = CEL_FOR_EVALUATING_BLOCK.createProgram(ast).eval();

assertThat(evaluatedResult).isNotNull();
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_blockIndexNeverReferenced() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions(
"cel.block([get_true()], has(msg.single_int64) ? index0 : false)");

boolean result =
(boolean)
celRuntime
.createProgram(ast)
.eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance()));

assertThat(result).isFalse();
assertThat(invocation.get()).isEqualTo(0);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions("cel.block([get_true()], index0 && index0 && index0)");

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions(
"cel.block([get_true(), get_true(), get_true()], index0 && index0 && index1 && index1"
+ " && index2 && index2)");

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(3);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_multipleBlockIndices_cascaded() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions("cel.block([get_true(), index0, index1], index2)");

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@TestParameters("{source: 'cel.block([])'}")
@TestParameters("{source: 'cel.block([1])'}")
@TestParameters("{source: 'cel.block(1, 2)'}")
@TestParameters("{source: 'cel.block(1, [1])'}")
public void block_invalidArguments_throws(String source) {
CelValidationException e =
assertThrows(CelValidationException.class, () -> compileUsingInternalFunctions(source));

assertThat(e).hasMessageThat().contains("found no matching overload for 'cel.block'");
}

@Test
public void blockIndex_invalidArgument_throws() {
CelValidationException e =
assertThrows(
CelValidationException.class,
() -> compileUsingInternalFunctions("cel.block([1], index)"));

assertThat(e).hasMessageThat().contains("undeclared reference");
}

/**
* Converts AST containing cel.block related test functions to internal functions (e.g: cel.block
* -> cel.@block)
*/
private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expression)
throws CelValidationException {
MutableAst mutableAst = MutableAst.newInstance(1000);
CelAbstractSyntaxTree astToModify = CEL_FOR_EVALUATING_BLOCK.compile(expression).getAst();
while (true) {
CelExpr celExpr =
CelNavigableAst.fromAst(astToModify)
.getRoot()
.allNodes()
.filter(node -> node.getKind().equals(Kind.CALL))
.map(CelNavigableExpr::expr)
.filter(expr -> expr.call().function().equals("cel.block"))
.findAny()
.orElse(null);
if (celExpr == null) {
break;
}
astToModify =
mutableAst.replaceSubtree(
astToModify,
celExpr.toBuilder()
.setCall(celExpr.call().toBuilder().setFunction("cel.@block").build())
.build(),
celExpr.id());
}

while (true) {
CelExpr celExpr =
CelNavigableAst.fromAst(astToModify)
.getRoot()
.allNodes()
.filter(node -> node.getKind().equals(Kind.IDENT))
.map(CelNavigableExpr::expr)
.filter(expr -> expr.ident().name().startsWith("index"))
.findAny()
.orElse(null);
if (celExpr == null) {
break;
}
String internalIdentName = "@" + celExpr.ident().name();
astToModify =
mutableAst.replaceSubtree(
astToModify,
celExpr.toBuilder()
.setIdent(celExpr.ident().toBuilder().setName(internalIdentName).build())
.build(),
celExpr.id());
}

return CEL_FOR_EVALUATING_BLOCK.check(astToModify).getAst();
}
}
Loading

0 comments on commit c916a11

Please sign in to comment.