diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index 9a9f22c9..c7f16646 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -70,6 +70,36 @@ default CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( ast, varName, varInit, resultExpr, exprIdToReplace); } + /** + * Replaces all comprehension identifier names with a unique name based on the given prefix. + * + *

The purpose of this is to avoid errors that can be caused by shadowed variables while + * augmenting an AST. As an example: {@code [2, 3].exists(x, x - 1 > 3) || x - 1 > 3}. Note that + * the scoping of `x - 1` is different between the two LOGICAL_OR branches. Iteration variable `x` + * in `exists` will be mangled to {@code [2, 3].exists(@c0, @c0 - 1 > 3) || x - 1 > 3} to avoid + * erroneously extracting x - 1 as a common subexpression. + * + *

The expression IDs are not modified when the identifier names are changed. + * + *

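Iteration variables in comprehensions are numbered based on their comprehension nesting + * levels. Examples: + * + * {@code [1].exists(i, i > 0) || [1].exists(j, j > 0)} becomes + * {@code [1].exists(@c0, @c0 > 0) || [1].exists(@c0, @c0 > 0)} (sibling comprehensions are both + * at nesting level 0), and {@code [x].exists(x, [x].exists(x, x == 1))} becomes + * {@code [x].exists(@c0, [@c0].exists(@c1, @c1 == 1))} (the inner comprehension is at level 1). + * + *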

+ * + * @param ast AST to mutate + * @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will + * produce @c0, @c1, @c2... as new names. + */ + default CelAbstractSyntaxTree mangleComprehensionIdentifierNames( + CelAbstractSyntaxTree ast, String newIdentPrefix) { + return MutableAst.mangleComprehensionIdentifierNames(ast, newIdentPrefix); + } + /** Sets all expr IDs in the expression tree to 0. */ default CelExpr clearExprIds(CelExpr celExpr) { return MutableAst.clearExprIds(celExpr); diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index f1f57925..45cbb0ff 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -30,20 +30,26 @@ import dev.cel.common.ast.CelExpr.CelCreateList; import dev.cel.common.ast.CelExpr.CelCreateMap; import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.CelSelect; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelExprFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator; import dev.cel.common.ast.CelExprIdGeneratorFactory.StableIdGenerator; +import dev.cel.common.navigation.CelNavigableAst; import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder; import java.util.Map.Entry; import java.util.NoSuchElementException; +import java.util.Optional; /** MutableAst contains logic for mutating a {@link CelExpr}. 
*/ @Internal final class MutableAst { private static final int MAX_ITERATION_COUNT = 1000; + private static final ExprIdGenerator NO_OP_ID_GENERATOR = id -> id; + private final CelExpr.Builder newExpr; private final ExprIdGenerator celExprIdGenerator; private int iterationCount; @@ -160,6 +166,132 @@ static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) return CelAbstractSyntaxTree.newParsedAst(root.build(), newSource); } + static CelAbstractSyntaxTree mangleComprehensionIdentifierNames( + CelAbstractSyntaxTree ast, String newIdentPrefix) { + int iterCount; + CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast); + for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) { + Optional<CelNavigableExpr> maybeComprehensionExpr = + newNavigableAst + .getRoot() + // Important: mangle bottom-up (POST_ORDER, taking the first match in encounter order) + // so that shadowed variables outside the comprehension being mangled are not stepped over. + .allNodes(TraversalOrder.POST_ORDER) + .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) + .filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix)) + .findFirst(); + if (!maybeComprehensionExpr.isPresent()) { + break; + } + + CelExpr.Builder comprehensionExpr = maybeComprehensionExpr.get().expr().toBuilder(); + String iterVar = comprehensionExpr.comprehension().iterVar(); + int comprehensionNestingLevel = countComprehensionNestingLevel(maybeComprehensionExpr.get()); + String mangledVarName = newIdentPrefix + comprehensionNestingLevel; + + CelExpr.Builder mutatedComprehensionExpr = + mangleIdentsInComprehensionExpr( + newNavigableAst.getAst().getExpr().toBuilder(), + comprehensionExpr, + iterVar, + mangledVarName); + // Repeat the mangling process for the macro source. 
+ CelSource newSource = + mangleIdentsInMacroSource( + newNavigableAst.getAst(), + mutatedComprehensionExpr, + iterVar, + mangledVarName, + comprehensionExpr.id()); + + newNavigableAst = + CelNavigableAst.fromAst( + CelAbstractSyntaxTree.newParsedAst(mutatedComprehensionExpr.build(), newSource)); + } + + if (iterCount >= MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + + return newNavigableAst.getAst(); + } + + /** Renames {@code originalIterVar} idents under the comprehension's loop step, sets its iterVar to {@code mangledVarName}, and replaces the comprehension in {@code root} (via replaceSubtreeImpl, expr IDs kept as-is). */ private static CelExpr.Builder mangleIdentsInComprehensionExpr( + CelExpr.Builder root, + CelExpr.Builder comprehensionExpr, + String originalIterVar, + String mangledVarName) { + int iterCount; + for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) { + Optional<CelExpr> identToMangle = + CelNavigableExpr.fromExpr(comprehensionExpr.comprehension().loopStep()) + .descendants() + .map(CelNavigableExpr::expr) + .filter(node -> node.identOrDefault().name().equals(originalIterVar)) + .findAny(); + if (!identToMangle.isPresent()) { + break; + } + + comprehensionExpr = + replaceSubtreeImpl( + NO_OP_ID_GENERATOR, + comprehensionExpr, + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), + identToMangle.get().id()); + } + + if (iterCount >= MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + + return replaceSubtreeImpl( + NO_OP_ID_GENERATOR, + root, + comprehensionExpr.setComprehension( + comprehensionExpr.comprehension().toBuilder().setIterVar(mangledVarName).build()), + comprehensionExpr.id()); + } + + /** Returns the source unchanged if no macro call is recorded for {@code originalComprehensionId}; otherwise renames the iteration variable in that macro call. */ private static CelSource mangleIdentsInMacroSource( + CelAbstractSyntaxTree ast, + CelExpr.Builder mutatedComprehensionExpr, + String originalIterVar, + String mangledVarName, + long originalComprehensionId) { + if (!ast.getSource().getMacroCalls().containsKey(originalComprehensionId)) { + return ast.getSource(); + } + + // First, normalize the macro source. 
+ // ex: [x].exists(x, [x].exists(x, x == 1)) -> [x].exists(x, [@c0].exists(x, @c1 == 1)). + CelSource.Builder newSource = + normalizeMacroSource(ast.getSource(), -1, mutatedComprehensionExpr, (id) -> id).toBuilder(); + + // Note that in the above example, the iteration variable is not replaced after normalization. + // This is because populating a macro call map upon parse generates a new unique identifier + // that does not exist in the main AST. Thus, we need to manually replace the identifier. + CelExpr.Builder macroExpr = newSource.getMacroCalls().get(originalComprehensionId).toBuilder(); + // By convention, the iteration variable is always the first argument of the + // macro call expression. + CelExpr identToMangle = macroExpr.call().args().get(0); + if (!identToMangle.identOrDefault().name().equals(originalIterVar)) { + throw new IllegalStateException( + String.format( + "Expected %s for iteration variable but got %s instead.", + originalIterVar, identToMangle.identOrDefault().name())); + } + macroExpr = + replaceSubtreeImpl( + NO_OP_ID_GENERATOR, + macroExpr, + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), + identToMangle.id()); + + newSource.addMacroCalls(originalComprehensionId, macroExpr.build()); + return newSource.build(); + } + private static BindMacro newBindMacro( String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) { // Renumber incoming expression IDs in the init and result expression to avoid collision with @@ -344,6 +476,19 @@ private static long getMaxId(CelExpr newExpr) { .orElseThrow(NoSuchElementException::new); } + /** Counts the COMPREHENSION-kind ancestors of {@code comprehensionExpr} (0 if there are none). */ private static int countComprehensionNestingLevel(CelNavigableExpr comprehensionExpr) { + int nestedLevel = 0; + Optional<CelNavigableExpr> maybeParent = comprehensionExpr.parent(); + while (maybeParent.isPresent()) { + if (maybeParent.get().getKind().equals(Kind.COMPREHENSION)) { + nestedLevel++; + } + + maybeParent = maybeParent.get().parent(); + } + return 
nestedLevel; + } + private CelExpr.Builder visit(CelExpr.Builder expr) { if (++iterationCount > MAX_ITERATION_COUNT) { throw new IllegalStateException("Max iteration count reached."); diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 88ae694e..d2f1ae7e 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -63,6 +63,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer { private static final SubexpressionOptimizer INSTANCE = new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; + private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c"; private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = Streams.concat( stream(Operator.values()).map(Operator::getFunction), @@ -88,8 +89,11 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c @Override public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { - CelAbstractSyntaxTree astToModify = navigableAst.getAst(); + CelAbstractSyntaxTree astToModify = + mangleComprehensionIdentifierNames( + navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); CelSource sourceToModify = astToModify.getSource(); + int bindIdentifierIndex = 0; int iterCount; for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { @@ -247,9 +251,10 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) { if (parent.getKind().equals(Kind.COMPREHENSION)) { return Streams.concat( // If the expression is within a comprehension, it is eligible for CSE iff is in - // result or iterRange. 
While result is not human authored, it needs to be included - // to extract subexpressions that are already in cel.bind macro. + // result, loopStep or iterRange. While result is not human authored, it needs to be + // included to extract subexpressions that are already in cel.bind macro. CelNavigableExpr.fromExpr(parent.expr().comprehension().result()).descendants(), + CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).descendants(), CelNavigableExpr.fromExpr(parent.expr().comprehension().iterRange()).allNodes()) .filter( node -> diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index 587c3c71..8f79838c 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -652,6 +652,197 @@ public void comprehension_replaceLoopStep() throws Exception { assertConsistentMacroCalls(ast); } + @Test + public void mangleComprehensionVariable_singleMacro() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); + + CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + + assertThat(mangledAst.getExpr().toString()) + .isEqualTo( + "COMPREHENSION [13] {\n" + + " iter_var: @c0\n" + + " iter_range: {\n" + + " CREATE_LIST [1] {\n" + + " elements: {\n" + + " CONSTANT [2] { value: false }\n" + + " }\n" + + " }\n" + + " }\n" + + " accu_var: __result__\n" + + " accu_init: {\n" + + " CONSTANT [6] { value: false }\n" + + " }\n" + + " loop_condition: {\n" + + " CALL [9] {\n" + + " function: @not_strictly_false\n" + + " args: {\n" + + " CALL [8] {\n" + + " function: !_\n" + + " args: {\n" + + " IDENT [7] {\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " loop_step: {\n" + + " CALL [11] {\n" + + " function: _||_\n" + + " args: {\n" + + " IDENT [10] {\n" + + " name: 
__result__\n" + + " }\n" + + " IDENT [5] {\n" + + " name: @c0\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " result: {\n" + + " IDENT [12] {\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + "}"); + assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0, @c0)"); + assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval()).isEqualTo(false); + assertConsistentMacroCalls(ast); + } + + @Test + public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst(); + + CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + + assertThat(mangledAst.getExpr().toString()) + .isEqualTo( + "COMPREHENSION [27] {\n" + + " iter_var: @c0\n" + + " iter_range: {\n" + + " CREATE_LIST [1] {\n" + + " elements: {\n" + + " IDENT [2] {\n" + + " name: x\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " accu_var: __result__\n" + + " accu_init: {\n" + + " CONSTANT [20] { value: false }\n" + + " }\n" + + " loop_condition: {\n" + + " CALL [23] {\n" + + " function: @not_strictly_false\n" + + " args: {\n" + + " CALL [22] {\n" + + " function: !_\n" + + " args: {\n" + + " IDENT [21] {\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " loop_step: {\n" + + " CALL [25] {\n" + + " function: _||_\n" + + " args: {\n" + + " IDENT [24] {\n" + + " name: __result__\n" + + " }\n" + + " COMPREHENSION [19] {\n" + + " iter_var: @c1\n" + + " iter_range: {\n" + + " CREATE_LIST [5] {\n" + + " elements: {\n" + + " IDENT [6] {\n" + + " name: @c0\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " accu_var: __result__\n" + + " accu_init: {\n" + + " CONSTANT [12] { value: false }\n" + + " }\n" + + " loop_condition: {\n" + + " CALL [15] {\n" + + " function: @not_strictly_false\n" + + " args: {\n" + + " CALL [14] {\n" + + " function: !_\n" + + " args: {\n" + + " IDENT [13] 
{\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " loop_step: {\n" + + " CALL [17] {\n" + + " function: _||_\n" + + " args: {\n" + + " IDENT [16] {\n" + + " name: __result__\n" + + " }\n" + + " CALL [10] {\n" + + " function: _==_\n" + + " args: {\n" + + " IDENT [9] {\n" + + " name: @c1\n" + + " }\n" + + " CONSTANT [11] { value: 1 }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " result: {\n" + + " IDENT [18] {\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " result: {\n" + + " IDENT [26] {\n" + + " name: __result__\n" + + " }\n" + + " }\n" + + "}"); + + assertThat(CEL_UNPARSER.unparse(mangledAst)) + .isEqualTo("[x].exists(@c0, [@c0].exists(@c1, @c1 == 1))"); + assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval(ImmutableMap.of("x", 1))) + .isEqualTo(true); + assertConsistentMacroCalls(ast); + } + + @Test + public void mangleComprehensionVariable_hasMacro_noOp() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst(); + + CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + + assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)"); + assertThat( + CEL.createProgram(CEL.check(mangledAst).getAst()) + .eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance()))) + .isEqualTo(false); + assertConsistentMacroCalls(ast); + } + /** * Asserts that the expressions that appears in source_info's macro calls are consistent with the * actual expr nodes in the AST. 
diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 15741750..79e04fc6 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -304,20 +304,25 @@ private enum CseTestCase { "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? " + "cel.bind(@r1, msg.single_int32, (@r1 > 0) ? (@r0 + @r1) : 0) : 0) == 8"), MULTIPLE_MACROS( - "size([[1].exists(x, x > 0)]) + size([[1].exists(x, x > 0)]) + " - + "size([[2].exists(x, x > 1)]) + size([[2].exists(x, x > 1)]) == 4", - "cel.bind(@r1, size([[2].exists(x, x > 1)]), " - + "cel.bind(@r0, size([[1].exists(x, x > 0)]), @r0 + @r0) + @r1 + @r1) == 4"), + // Note that all of these have different iteration variables, but they are still logically + // the same. + "size([[1].exists(i, i > 0)]) + size([[1].exists(j, j > 0)]) + " + + "size([[2].exists(k, k > 1)]) + size([[2].exists(l, l > 1)]) == 4", + "cel.bind(@r1, size([[2].exists(@c0, @c0 > 1)]), " + + "cel.bind(@r0, size([[1].exists(@c0, @c0 > 0)]), @r0 + @r0) + @r1 + @r1) == 4"), NESTED_MACROS( "[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]", - "cel.bind(@r0, [1, 2, 3], @r0.map(i, @r0.map(i, i + 1))) == cel.bind(@r1, [2, 3, 4], [@r1," - + " @r1, @r1])"), + "cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == " + + "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"), MACRO_SHADOWED_VARIABLE( - // Macro variable `x` in .exists is shadowed. - // This is left intact due to the fact that loop condition is not optimized at the moment. "[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3", - "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(x, x - 1 > 3) ||" - + " @r1))"); + "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? 
@r0 : 5].exists(@c0, @c0 - 1 > 3) ||" + + " @r1))"), + MACRO_SHADOWED_VARIABLE_2( + "size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2", + "size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))" + + ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"), + ; private final String source; private final String unparsed; @@ -379,8 +384,6 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") // Duplicated but nested calls. @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") - // Loop condition is not optimized at the moment. This requires mangling. - @TestParameters("{source: '[\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])'}") // Ternary with presence test is not supported yet. @TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}") public void cse_noop(String source) throws Exception { @@ -502,9 +505,9 @@ public void cse_largeNestedMacro() throws Exception { assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( - "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, size(@r0.map(i, @r0.map(i, @r0.map(i," - + " @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, [1, 2, 3]))))))))), @r1" - + " + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1))"); + "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, size(@r0.map(@c0, @r0.map(@c1, @r0.map(@c2, " + + "@r0.map(@c3, @r0.map(@c4, @r0.map(@c5, @r0.map(@c6, @r0.map(@c7, @r0))))))))), " + + "@r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1))"); assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); }