From cd11baf909a722e28045ebdd9411902880102ab2 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 19 Jan 2024 11:14:04 -0800 Subject: [PATCH] Perform CSE on presence tests PiperOrigin-RevId: 599886305 --- .../dev/cel/common/ast/CelExprFormatter.java | 4 +- .../cel/common/ast/CelExprFormatterTest.java | 33 ++++++ .../dev/cel/optimizer/CelAstOptimizer.java | 18 +++ .../java/dev/cel/optimizer/MutableAst.java | 9 ++ .../optimizers/SubexpressionOptimizer.java | 54 ++++++++- .../SubexpressionOptimizerTest.java | 103 +++++++++++++++++- 6 files changed, 209 insertions(+), 12 deletions(-) diff --git a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java index 8ab71c09..32e0233d 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java @@ -122,8 +122,8 @@ private void appendSelect(CelExpr.CelSelect celSelect) { indent(); formatExpr(celSelect.operand()); outdent(); - append("."); - append(celSelect.field()); + appendWithoutIndent("."); + appendWithoutIndent(celSelect.field()); if (celSelect.testOnly()) { appendWithoutIndent("~presence_test"); } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java index 2c863b1f..0efb890c 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java @@ -336,4 +336,37 @@ public void comprehension() throws Exception { + " }\n" + "}"); } + + @Test + public void ternaryWithPresenceTest() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .build(); + CelAbstractSyntaxTree ast = + celCompiler.compile("has(msg.single_any) ? msg.single_any : 10").getAst(); + + String formattedExpr = CelExprFormatter.format(ast.getExpr()); + + assertThat(formattedExpr) + .isEqualTo( + "CALL [5] {\n" + + " function: _?_:_\n" + + " args: {\n" + + " SELECT [4] {\n" + + " IDENT [2] {\n" + + " name: msg\n" + + " }.single_any~presence_test\n" + + " }\n" + + " SELECT [7] {\n" + + " IDENT [6] {\n" + + " name: msg\n" + + " }.single_any\n" + + " }\n" + + " CONSTANT [8] { value: 10 }\n" + + " }\n" + + "}"); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index c7f16646..feac4ce9 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -26,6 +26,24 @@ public interface CelAstOptimizer { CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) throws CelOptimizationException; + /** + * Replaces a subtree in the given expression node. This operation is intended for AST + * optimization purposes. + * + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + * @param celExpr Original expression node to rewrite. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. + */ + default CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) { + return MutableAst.replaceSubtree(celExpr, newExpr, exprIdToReplace); + } + /** * Replaces a subtree in the given AST. This operation is intended for AST optimization purposes. * diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 45cbb0ff..1c342716 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -66,6 +66,15 @@ static CelExpr clearExprIds(CelExpr celExpr) { return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build(); } + /** Mutates the given {@link CelExpr} by replacing a subtree at the given index. */ + static CelExpr replaceSubtree(CelExpr expr, CelExpr newExpr, long exprIdToReplace) { + return replaceSubtree( + CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()), + CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()), + exprIdToReplace) + .getExpr(); + } + /** * Mutates the given AST by replacing a subtree at a given index. * diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index d2f1ae7e..d7eda5b4 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -97,7 +97,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { int bindIdentifierIndex = 0; int iterCount; for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { - CelNavigableExpr cseCandidate = findCseCandidate(astToModify).orElse(null); + CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null); if (cseCandidate == null) { break; } @@ -107,7 +107,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { // Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time. ImmutableList allCseCandidates = - getAllCseCandidatesStream(astToModify, cseCandidate.expr()).collect(toImmutableList()); + getAllCseCandidatesStream(astToModify, cseCandidate).collect(toImmutableList()); // Replace all CSE candidates with new bind identifier for (CelExpr semanticallyEqualNode : allCseCandidates) { @@ -142,7 +142,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { // Insert the new bind call astToModify = replaceSubtreeWithNewBindMacro( - astToModify, bindIdentifier, cseCandidate.expr(), lca.expr(), lca.id()); + astToModify, bindIdentifier, cseCandidate, lca.expr(), lca.id()); // Retain the existing macro calls in case if the bind identifiers are replacing a subtree // that contains a comprehension. @@ -224,8 +224,8 @@ private Optional findCseCandidate(CelAbstractSyntaxTree ast) { .collect(toImmutableList()); for (CelNavigableExpr node : allNodes) { - // Strip out all IDs to test equivalence - CelExpr celExpr = clearExprIds(node.expr()); + // Normalize the expr to test semantic equivalence. + CelExpr celExpr = normalizeForEquality(node.expr()); if (encounteredNodes.contains(celExpr)) { return Optional.of(node); } @@ -240,6 +240,7 @@ private static boolean canEliminate(CelNavigableExpr navigableExpr) { return !navigableExpr.getKind().equals(Kind.CONSTANT) && !navigableExpr.getKind().equals(Kind.IDENT) && !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX) + && !navigableExpr.expr().selectOrDefault().testOnly() && isAllowedFunction(navigableExpr) && isWithinInlineableComprehension(navigableExpr); } @@ -271,7 +272,7 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) { } private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) { - return clearExprIds(expr1).equals(clearExprIds(expr2)); + return normalizeForEquality(expr1).equals(normalizeForEquality(expr2)); } private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) { @@ -282,6 +283,47 @@ private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) { return true; } + /** + * Converts the {@link CelExpr} to make it suitable for performing semantically equals check in + * {@link #areSemanticallyEqual(CelExpr, CelExpr)}. + * + *

Specifically, this will: + * + *

+ */ + private CelExpr normalizeForEquality(CelExpr celExpr) { + int iterCount; + for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { + CelExpr presenceTestExpr = + CelNavigableExpr.fromExpr(celExpr) + .allNodes() + .map(CelNavigableExpr::expr) + .filter(expr -> expr.selectOrDefault().testOnly()) + .findAny() + .orElse(null); + if (presenceTestExpr == null) { + break; + } + + CelExpr newExpr = + presenceTestExpr.toBuilder() + .setSelect(presenceTestExpr.select().toBuilder().setTestOnly(false).build()) + .build(); + + celExpr = replaceSubtree(celExpr, newExpr, newExpr.id()); + } + + if (iterCount >= cseOptions.maxIterationLimit()) { + throw new IllegalStateException("Max iteration count reached."); + } + + return clearExprIds(celExpr); + } + /** Options to configure how Common Subexpression Elimination behave. */ @AutoValue public abstract static class SubexpressionOptimizerOptions { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 79e04fc6..98e1ff75 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -31,6 +31,7 @@ import dev.cel.common.CelOptions; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.OptionalType; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; import dev.cel.extensions.CelExtensions; @@ -75,7 +76,8 @@ public class SubexpressionOptimizerTest { .setSingleInt64(10L) .putMapInt32Int64(0, 1) .putMapInt32Int64(1, 5) - .putMapInt32Int64(2, 2))) + .putMapInt32Int64(2, 2) + .putMapStringString("key", "A"))) .build(); private static CelBuilder newCelBuilder() { @@ -92,6 +94,7 @@ private static CelBuilder newCelBuilder() { "custom_func", newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT))) .addVar("x", SimpleType.DYN) + .addVar("opt_x", OptionalType.create(SimpleType.DYN)) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); } @@ -314,6 +317,13 @@ private enum CseTestCase { "[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]", "cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == " + "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"), + INCLUSION_LIST( + "1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]", + "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&" + + " @r1))"), + INCLUSION_MAP( + "2 in {'a': 1, 2: {true: false}, 3: {true: false}}", + "2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})"), MACRO_SHADOWED_VARIABLE( "[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3", "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0, @c0 - 1 > 3) ||" @@ -322,6 +332,86 @@ private enum CseTestCase { "size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2", "size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))" + ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"), + PRESENCE_TEST( + "has({'a': true}.a) && {'a':true}['a']", + "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])"), + PRESENCE_TEST_WITH_TERNARY( + "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10", + "cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10"), + PRESENCE_TEST_WITH_TERNARY_2( + "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :" + + " msg.oneof_type.payload.single_int64 * 0) == 10", + "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?" + + " @r1 : (@r1 * 0))) == 10"), + PRESENCE_TEST_WITH_TERNARY_3( + "(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :" + + " msg.oneof_type.payload.single_int64 * 0) == 10", + "cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64," + + " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10"), + /** + * Input: + * + *
{@code
+     * (
+     *   has(msg.oneof_type) &&
+     *   has(msg.oneof_type.payload) &&
+     *   has(msg.oneof_type.payload.single_int64)
+     * ) ?
+     *   (
+     *     (
+     *       has(msg.oneof_type.payload.map_string_string) &&
+     *       has(msg.oneof_type.payload.map_string_string.key)
+     *     ) ?
+     *       msg.oneof_type.payload.map_string_string.key == "A"
+     *     : false
+     *   )
+     * : false
+     * }
+ * + * Unparsed: + * + *
{@code
+     * cel.bind(
+     *   @r0, msg.oneof_type,
+     *   cel.bind(
+     *     @r1, @r0.payload,
+     *     has(msg.oneof_type) && has(@r0.payload) && has(@r1.single_int64) ?
+     *       cel.bind(
+     *         @r2, @r1.map_string_string,
+     *         has(@r1.map_string_string) && has(@r2.key) ? @r2.key == "A" : false,
+     *       )
+     *     : false,
+     *   ),
+     * )
+     * }
+ */ + PRESENCE_TEST_WITH_TERNARY_NESTED( + "(has(msg.oneof_type) && has(msg.oneof_type.payload) &&" + + " has(msg.oneof_type.payload.single_int64)) ?" + + " ((has(msg.oneof_type.payload.map_string_string) &&" + + " has(msg.oneof_type.payload.map_string_string.key)) ?" + + " msg.oneof_type.payload.map_string_string.key == 'A' : false) : false", + "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload, (has(msg.oneof_type) &&" + + " has(@r0.payload) && has(@r1.single_int64)) ? cel.bind(@r2, @r1.map_string_string," + + " (has(@r1.map_string_string) && has(@r2.key)) ? (@r2.key == \"A\") : false) :" + + " false))"), + OPTIONAL_LIST( + "[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10," + + " [5], [5]]", + "cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) ==" + + " cel.bind(@r1, [5], [10, @r1, @r1])"), + OPTIONAL_MAP( + "{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] ==" + + " 'hellohello'", + "cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) ==" + + " \"hellohello\""), + OPTIONAL_MESSAGE( + "TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + + " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:" + + " optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}.single_int64 == 5", + "cel.bind(@r0, TestAllTypes{" + + "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, " + + "@r0.single_int32 + @r0.single_int64) == 5"), ; private final String source; @@ -342,7 +432,9 @@ public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCas assertThat( CEL.createProgram(optimizedAst) - .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L))) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) .isEqualTo(true); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed); } @@ -366,7 +458,9 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr assertThat( celWithoutMacroMap .createProgram(optimizedAst) - .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L))) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) .isEqualTo(true); } @@ -384,7 +478,8 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") // Duplicated but nested calls. @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") - // Ternary with presence test is not supported yet. + // This cannot be optimized. Extracting the common subexpression would presence test + // the bound identifier (e.g: has(@r0)), which is not valid. @TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}") public void cse_noop(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst();