From c916a111a2d75f1fa6b047a50991170d81bcb345 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Mon, 12 Feb 2024 11:39:14 -0800 Subject: [PATCH] Add capability to evaluate cel.block calls in the runtime PiperOrigin-RevId: 606316063 --- .../dev/cel/optimizer/optimizers/BUILD.bazel | 3 + .../optimizers/SubexpressionOptimizer.java | 15 ++ .../dev/cel/optimizer/optimizers/BUILD.bazel | 3 + .../SubexpressionOptimizerTest.java | 252 ++++++++++++++++++ .../dev/cel/runtime/DefaultInterpreter.java | 22 +- 5 files changed, 292 insertions(+), 3 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 10096a73..77b3d289 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -42,8 +42,11 @@ java_library( "//bundle:cel", "//checker:checker_legacy_environment", "//common", + "//common:compiler_common", "//common/ast", "//common/navigation", + "//common/types", + "//common/types:type_providers", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", "//parser:operator", diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 8d057c74..241680bb 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -19,12 +19,15 @@ import static java.util.Arrays.stream; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; import dev.cel.bundle.Cel; import dev.cel.checker.Standard; import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelSource; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelIdent; @@ -32,6 +35,9 @@ import dev.cel.common.navigation.CelNavigableAst; import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder; +import dev.cel.common.types.CelType; +import dev.cel.common.types.ListType; +import dev.cel.common.types.SimpleType; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.MutableAst; import dev.cel.parser.Operator; @@ -64,6 +70,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer { new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c"; + private static final String CEL_BLOCK_FUNCTION = "cel.@block"; private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = Streams.concat( stream(Operator.values()).map(Operator::getFunction), @@ -325,6 +332,14 @@ private CelExpr normalizeForEquality(CelExpr celExpr) { return mutableAst.clearExprIds(celExpr); } + @VisibleForTesting + static CelFunctionDecl newCelBlockFunctionDecl(CelType resultType) { + return CelFunctionDecl.newFunctionDeclaration( + CEL_BLOCK_FUNCTION, + CelOverloadDecl.newGlobalOverload( + "cel_block_list", resultType, ListType.create(SimpleType.DYN), resultType)); + } + /** Options to configure how Common Subexpression Elimination behave. */ @AutoValue public abstract static class SubexpressionOptimizerOptions { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index a4f8143f..c907ef1f 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -13,11 +13,13 @@ java_library( "//common:compiler_common", "//common:options", "//common/ast", + "//common/navigation", "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types", "//extensions", "//extensions:optional_library", "//optimizer", + "//optimizer:mutable_ast", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", "//optimizer/optimizers:common_subexpression_elimination", @@ -25,6 +27,7 @@ java_library( "//parser:macro", "//parser:operator", "//parser:unparser", + "//runtime", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:junit_junit", diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 9b516bdd..70b12f53 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -29,8 +29,15 @@ import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.CelValidationException; +import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.common.types.ListType; import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; @@ -39,14 +46,19 @@ import dev.cel.optimizer.CelOptimizationException; import dev.cel.optimizer.CelOptimizer; import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.optimizer.MutableAst; import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; import dev.cel.parser.Operator; +import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntime.CelFunctionBinding; +import dev.cel.runtime.CelRuntimeFactory; import dev.cel.testing.testdata.proto3.TestAllTypesProto.NestedTestAllTypes; import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; import org.junit.runner.RunWith; @@ -55,6 +67,37 @@ public class SubexpressionOptimizerTest { private static final Cel CEL = newCelBuilder().build(); + private static final Cel CEL_FOR_EVALUATING_BLOCK = + CelFactory.standardCelBuilder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFunctionDeclarations( + // These are test only declarations, as the actual function is made internal using @ + // symbol. + // If the main function declaration needs updating, be sure to update the test + // declaration as well. + CelFunctionDecl.newFunctionDeclaration( + "cel.block", + CelOverloadDecl.newGlobalOverload( + "block_test_only_overload", + SimpleType.DYN, + ListType.create(SimpleType.DYN), + SimpleType.DYN)), + SubexpressionOptimizer.newCelBlockFunctionDecl(SimpleType.DYN), + CelFunctionDecl.newFunctionDeclaration( + "get_true", + CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) + // Similarly, this is a test only decl (index0 -> @index0) + .addVarDeclarations( + CelVarDecl.newVarDeclaration("index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("index2", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index0", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index1", SimpleType.DYN), + CelVarDecl.newVarDeclaration("@index2", SimpleType.DYN)) + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .build(); + private static final CelOptimizer CEL_OPTIMIZER = CelOptimizerFactory.standardCelOptimizerBuilder(CEL) .addAstOptimizers( @@ -659,4 +702,213 @@ public void iterationLimitReached_throws() throws Exception { .optimize(ast)); assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached."); } + + private enum BlockTestCase { + BOOL_LITERAL("cel.block([true, false], index0 || index1)"), + STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd') == 'abcd'"), + + BLOCK_WITH_EXISTS_TRUE("cel.block([[1, 2, 3], [3, 4, 5].exists(e, e in index0)], index1)"), + BLOCK_WITH_EXISTS_FALSE("cel.block([[1, 2, 3], ![4, 5].exists(e, e in index0)], index1)"), + ; + + private final String source; + + BlockTestCase(String source) { + this.source = source; + } + } + + @Test + public void block_success(@TestParameter BlockTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = compileUsingInternalFunctions(testCase.source); + + Object evaluatedResult = CEL_FOR_EVALUATING_BLOCK.createProgram(ast).eval(); + + assertThat(evaluatedResult).isNotNull(); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_blockIndexNeverReferenced() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions( + "cel.block([get_true()], has(msg.single_int64) ? index0 : false)"); + + boolean result = + (boolean) + celRuntime + .createProgram(ast) + .eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); + + assertThat(result).isFalse(); + assertThat(invocation.get()).isEqualTo(0); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_blockIndexEvaluatedOnlyOnce() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions("cel.block([get_true()], index0 && index0 && index0)"); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + assertThat(invocation.get()).isEqualTo(1); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_multipleBlockIndices_inResultExpr() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions( + "cel.block([get_true(), get_true(), get_true()], index0 && index0 && index1 && index1" + + " && index2 && index2)"); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + assertThat(invocation.get()).isEqualTo(3); + } + + @Test + @SuppressWarnings("Immutable") // Test only + public void lazyEval_multipleBlockIndices_cascaded() throws Exception { + AtomicInteger invocation = new AtomicInteger(); + CelRuntime celRuntime = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addFunctionBindings( + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + })) + .build(); + CelAbstractSyntaxTree ast = + compileUsingInternalFunctions("cel.block([get_true(), index0, index1], index2)"); + + boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + assertThat(result).isTrue(); + assertThat(invocation.get()).isEqualTo(1); + } + + @Test + @TestParameters("{source: 'cel.block([])'}") + @TestParameters("{source: 'cel.block([1])'}") + @TestParameters("{source: 'cel.block(1, 2)'}") + @TestParameters("{source: 'cel.block(1, [1])'}") + public void block_invalidArguments_throws(String source) { + CelValidationException e = + assertThrows(CelValidationException.class, () -> compileUsingInternalFunctions(source)); + + assertThat(e).hasMessageThat().contains("found no matching overload for 'cel.block'"); + } + + @Test + public void blockIndex_invalidArgument_throws() { + CelValidationException e = + assertThrows( + CelValidationException.class, + () -> compileUsingInternalFunctions("cel.block([1], index)")); + + assertThat(e).hasMessageThat().contains("undeclared reference"); + } + + /** + * Converts AST containing cel.block related test functions to internal functions (e.g: cel.block + * -> cel.@block) + */ + private static CelAbstractSyntaxTree compileUsingInternalFunctions(String expression) + throws CelValidationException { + MutableAst mutableAst = MutableAst.newInstance(1000); + CelAbstractSyntaxTree astToModify = CEL_FOR_EVALUATING_BLOCK.compile(expression).getAst(); + while (true) { + CelExpr celExpr = + CelNavigableAst.fromAst(astToModify) + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.CALL)) + .map(CelNavigableExpr::expr) + .filter(expr -> expr.call().function().equals("cel.block")) + .findAny() + .orElse(null); + if (celExpr == null) { + break; + } + astToModify = + mutableAst.replaceSubtree( + astToModify, + celExpr.toBuilder() + .setCall(celExpr.call().toBuilder().setFunction("cel.@block").build()) + .build(), + celExpr.id()); + } + + while (true) { + CelExpr celExpr = + CelNavigableAst.fromAst(astToModify) + .getRoot() + .allNodes() + .filter(node -> node.getKind().equals(Kind.IDENT)) + .map(CelNavigableExpr::expr) + .filter(expr -> expr.ident().name().startsWith("index")) + .findAny() + .orElse(null); + if (celExpr == null) { + break; + } + String internalIdentName = "@" + celExpr.ident().name(); + astToModify = + mutableAst.replaceSubtree( + astToModify, + celExpr.toBuilder() + .setIdent(celExpr.ident().toBuilder().setName(internalIdentName).build()) + .build(), + celExpr.id()); + } + + return CEL_FOR_EVALUATING_BLOCK.check(astToModify).getAst(); + } } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index f27882ed..3c4c0003 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -34,7 +34,6 @@ import dev.cel.common.ast.CelExpr.CelCreateList; import dev.cel.common.ast.CelExpr.CelCreateMap; import dev.cel.common.ast.CelExpr.CelCreateStruct; -import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.CelSelect; import dev.cel.common.ast.CelExpr.ExprKind; import dev.cel.common.ast.CelReference; @@ -194,7 +193,7 @@ private IntermediateResult evalInternal(ExecutionFrame frame, CelExpr expr) result = IntermediateResult.create(evalConstant(frame, expr, expr.constant())); break; case IDENT: - result = evalIdent(frame, expr, expr.ident()); + result = evalIdent(frame, expr); break; case SELECT: result = evalSelect(frame, expr, expr.select()); @@ -257,7 +256,7 @@ private Object evalConstant( } } - private IntermediateResult evalIdent(ExecutionFrame frame, CelExpr expr, CelIdent unusedIdent) + private IntermediateResult evalIdent(ExecutionFrame frame, CelExpr expr) throws InterpreterException { CelReference reference = ast.getReferenceOrThrow(expr.id()); if (reference.value().isPresent()) { @@ -371,6 +370,8 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall return result.get(); } break; + case "cel_block_list": + return evalCelBlock(frame, expr, callExpr); default: break; } @@ -846,6 +847,21 @@ private IntermediateResult evalComprehension( frame.popScope(); return result; } + + private IntermediateResult evalCelBlock( + ExecutionFrame frame, CelExpr unusedExpr, CelCall blockCall) throws InterpreterException { + CelCreateList exprList = blockCall.args().get(0).createList(); + ImmutableMap.Builder blockList = ImmutableMap.builder(); + for (int index = 0; index < exprList.elements().size(); index++) { + // Register the block indices as lazily evaluated expressions stored as unique identifiers. + blockList.put( + "@index" + index, + IntermediateResult.create(new LazyExpression(exprList.elements().get(index)))); + } + frame.pushScope(blockList.buildOrThrow()); + + return evalInternal(frame, blockCall.args().get(1)); + } } /** Contains a CelExpr that is to be lazily evaluated. */