From 90671c020c7b1d3c4ac408ae18d3ed2404f8e6fa Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 21 Feb 2024 12:59:35 -0800 Subject: [PATCH] Mangle identifier name for comprehension result PiperOrigin-RevId: 609094679 --- .../main/java/dev/cel/common/ast/CelExpr.java | 2 + .../java/dev/cel/optimizer/MutableAst.java | 183 +++++++++++++----- .../optimizers/SubexpressionOptimizer.java | 19 +- .../dev/cel/optimizer/MutableAstTest.java | 31 ++- 4 files changed, 169 insertions(+), 66 deletions(-) diff --git a/common/src/main/java/dev/cel/common/ast/CelExpr.java b/common/src/main/java/dev/cel/common/ast/CelExpr.java index c106a1ff..c4e53906 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExpr.java +++ b/common/src/main/java/dev/cel/common/ast/CelExpr.java @@ -1011,6 +1011,8 @@ public abstract static class CelComprehension { /** Builder for Comprehension. */ @AutoValue.Builder public abstract static class Builder { + public abstract String accuVar(); + public abstract CelExpr iterRange(); public abstract CelExpr accuInit(); diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 1e67a4c7..9333726a 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -29,6 +29,7 @@ import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelSource; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelExprFactory; @@ -45,6 +46,7 @@ import java.util.Map.Entry; import java.util.NoSuchElementException; import java.util.Optional; +import java.util.function.Predicate; import java.util.stream.Collectors; /** MutableAst contains logic for mutating a {@link CelAbstractSyntaxTree}. */ @@ -208,20 +210,27 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) * * * @param ast AST to mutate - * @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will - * produce @c0:0, @c0:1, @c1:0, @c2:0... as new names. + * @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example, + * providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names. + * @param newResultPrefix Prefix to use for new comprehensin result identifier names. */ public MangledComprehensionAst mangleComprehensionIdentifierNames( - CelAbstractSyntaxTree ast, String newIdentPrefix) { + CelAbstractSyntaxTree ast, String newIterVarPrefix, String newResultPrefix) { CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast); - LinkedHashMap comprehensionsToMangle = + Predicate comprehensionIdentifierPredicate = x -> true; + comprehensionIdentifierPredicate = + comprehensionIdentifierPredicate + .and(node -> node.getKind().equals(Kind.COMPREHENSION)) + .and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix)) + .and(node -> !node.expr().comprehension().accuVar().startsWith(newResultPrefix)); + + LinkedHashMap comprehensionsToMangle = newNavigableAst .getRoot() // This is important - mangling needs to happen bottom-up to avoid stepping over // shadowed variables that are not part of the comprehension being mangled. .allNodes(TraversalOrder.POST_ORDER) - .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) - .filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix)) + .filter(comprehensionIdentifierPredicate) .filter( node -> { // Ensure the iter_var is actually referenced in the loop_step. If it's not, we @@ -236,9 +245,10 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( Collectors.toMap( k -> k, v -> { - String iterVar = v.expr().comprehension().iterVar(); + CelComprehension comprehension = v.expr().comprehension(); + String iterVar = comprehension.iterVar(); long iterVarId = - CelNavigableExpr.fromExpr(v.expr().comprehension().loopStep()) + CelNavigableExpr.fromExpr(comprehension.loopStep()) .allNodes() .filter( loopStepNode -> @@ -252,11 +262,22 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( + v.id()); }); - return ast.getType(iterVarId) - .orElseThrow( - () -> - new NoSuchElementException( - "Checked type not present for: " + iterVarId)); + CelType iterVarType = + ast.getType(iterVarId) + .orElseThrow( + () -> + new NoSuchElementException( + "Checked type not present for iteration variable: " + + iterVarId)); + CelType resultType = + ast.getType(comprehension.result().id()) + .orElseThrow( + () -> + new NoSuchElementException( + "Checked type not present for result: " + + comprehension.result().id())); + + return MangledComprehensionType.of(iterVarType, resultType); }, (x, y) -> { throw new IllegalStateException("Unexpected CelNavigableExpr collision"); @@ -265,53 +286,62 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( int iterCount = 0; // The map that we'll eventually return to the caller. - HashMap mangledIdentNamesToType = new HashMap<>(); + HashMap mangledIdentNamesToType = + new HashMap<>(); // Intermediary table used for the purposes of generating a unique mangled variable name. - Table comprehensionLevelToType = HashBasedTable.create(); - for (Entry comprehensionEntry : comprehensionsToMangle.entrySet()) { + Table comprehensionLevelToType = + HashBasedTable.create(); + for (Entry comprehensionEntry : + comprehensionsToMangle.entrySet()) { iterCount++; // Refetch the comprehension node as mutating the AST could have renumbered its IDs. CelNavigableExpr comprehensionNode = newNavigableAst .getRoot() .allNodes(TraversalOrder.POST_ORDER) - .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) - .filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix)) + .filter(comprehensionIdentifierPredicate) .findAny() .orElseThrow( () -> new NoSuchElementException("Failed to refetch mutated comprehension")); - CelType comprehensionEntryType = comprehensionEntry.getValue(); + MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue(); CelExpr.Builder comprehensionExpr = comprehensionNode.expr().toBuilder(); - String iterVar = comprehensionExpr.comprehension().iterVar(); int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode); - String mangledVarName; + MangledComprehensionName mangledComprehensionName; if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) { - mangledVarName = + mangledComprehensionName = comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType); } else { // First time encountering the pair of . Generate a unique // mangled variable name for this. int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size(); - mangledVarName = newIdentPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx; + String mangledIterVarName = + newIterVarPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx; + String mangledResultName = + newResultPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx; + mangledComprehensionName = + MangledComprehensionName.of(mangledIterVarName, mangledResultName); comprehensionLevelToType.put( - comprehensionNestingLevel, comprehensionEntryType, mangledVarName); + comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName); } - mangledIdentNamesToType.put(mangledVarName, comprehensionEntryType); + mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType); + String iterVar = comprehensionExpr.comprehension().iterVar(); + String accuVar = comprehensionExpr.comprehension().accuVar(); CelExpr.Builder mutatedComprehensionExpr = mangleIdentsInComprehensionExpr( newNavigableAst.getAst().getExpr().toBuilder(), comprehensionExpr, iterVar, - mangledVarName); + accuVar, + mangledComprehensionName); // Repeat the mangling process for the macro source. CelSource newSource = mangleIdentsInMacroSource( newNavigableAst.getAst(), mutatedComprehensionExpr, iterVar, - mangledVarName, + mangledComprehensionName, comprehensionExpr.id()); newNavigableAst = @@ -381,14 +411,44 @@ private CelExpr.Builder mangleIdentsInComprehensionExpr( CelExpr.Builder root, CelExpr.Builder comprehensionExpr, String originalIterVar, - String mangledVarName) { + String originalAccuVar, + MangledComprehensionName mangledComprehensionName) { + CelExpr.Builder modifiedLoopStep = + replaceIdentName( + comprehensionExpr.comprehension().loopStep().toBuilder(), + originalIterVar, + mangledComprehensionName.iterVarName()); + comprehensionExpr.setComprehension( + comprehensionExpr.comprehension().toBuilder() + .setLoopStep(modifiedLoopStep.build()) + .build()); + comprehensionExpr = + replaceIdentName(comprehensionExpr, originalAccuVar, mangledComprehensionName.resultName()); + + CelComprehension.Builder newComprehension = + comprehensionExpr.comprehension().toBuilder() + .setIterVar(mangledComprehensionName.iterVarName()); + // Most standard macros set accu_var as __result__, but not all (ex: cel.bind). + if (newComprehension.accuVar().equals(originalAccuVar)) { + newComprehension.setAccuVar(mangledComprehensionName.resultName()); + } + + return mutateExpr( + NO_OP_ID_GENERATOR, + root, + comprehensionExpr.setComprehension(newComprehension.build()), + comprehensionExpr.id()); + } + + private CelExpr.Builder replaceIdentName( + CelExpr.Builder comprehensionExpr, String originalIdentName, String newIdentName) { int iterCount; for (iterCount = 0; iterCount < iterationLimit; iterCount++) { Optional identToMangle = - CelNavigableExpr.fromExpr(comprehensionExpr.comprehension().loopStep()) + CelNavigableExpr.fromExpr(comprehensionExpr.build()) .descendants() .map(CelNavigableExpr::expr) - .filter(node -> node.identOrDefault().name().equals(originalIterVar)) + .filter(node -> node.identOrDefault().name().equals(originalIdentName)) .findAny(); if (!identToMangle.isPresent()) { break; @@ -398,7 +458,7 @@ private CelExpr.Builder mangleIdentsInComprehensionExpr( mutateExpr( NO_OP_ID_GENERATOR, comprehensionExpr, - CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(newIdentName).build()), identToMangle.get().id()); } @@ -406,19 +466,14 @@ private CelExpr.Builder mangleIdentsInComprehensionExpr( throw new IllegalStateException("Max iteration count reached."); } - return mutateExpr( - NO_OP_ID_GENERATOR, - root, - comprehensionExpr.setComprehension( - comprehensionExpr.comprehension().toBuilder().setIterVar(mangledVarName).build()), - comprehensionExpr.id()); + return comprehensionExpr; } private CelSource mangleIdentsInMacroSource( CelAbstractSyntaxTree ast, CelExpr.Builder mutatedComprehensionExpr, String originalIterVar, - String mangledVarName, + MangledComprehensionName mangledComprehensionName, long originalComprehensionId) { if (!ast.getSource().getMacroCalls().containsKey(originalComprehensionId)) { return ast.getSource(); @@ -446,7 +501,9 @@ private CelSource mangleIdentsInMacroSource( mutateExpr( NO_OP_ID_GENERATOR, macroExpr, - CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), + CelExpr.newBuilder() + .setIdent( + CelIdent.newBuilder().setName(mangledComprehensionName.iterVarName()).build()), identToMangle.id()); newSource.addMacroCalls(originalComprehensionId, macroExpr.build()); @@ -652,8 +709,8 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension } /** - * Intermediate value class to store the mangled identifiers for iteration variable in the - * comprehension. + * Intermediate value class to store the mangled identifiers for iteration variable and the + * comprehension result. */ @AutoValue public abstract static class MangledComprehensionAst { @@ -662,11 +719,49 @@ public abstract static class MangledComprehensionAst { public abstract CelAbstractSyntaxTree ast(); /** Map containing the mangled identifier names to their types. */ - public abstract ImmutableMap mangledComprehensionIdents(); + public abstract ImmutableMap + mangledComprehensionMap(); private static MangledComprehensionAst of( - CelAbstractSyntaxTree ast, ImmutableMap mangledComprehensionIdents) { - return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents); + CelAbstractSyntaxTree ast, + ImmutableMap mangledComprehensionMap) { + return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionMap); + } + } + + /** + * Intermediate value class to store the types for iter_var and comprehension result of which its + * identifier names are being mangled. + */ + @AutoValue + public abstract static class MangledComprehensionType { + + /** Type of iter_var */ + public abstract CelType iterVarType(); + + /** Type of comprehension result */ + public abstract CelType resultType(); + + private static MangledComprehensionType of(CelType iterVarType, CelType resultType) { + return new AutoValue_MutableAst_MangledComprehensionType(iterVarType, resultType); + } + } + + /** + * Intermediate value class to store the mangled names for iteration variable and the + * comprehension result. + */ + @AutoValue + public abstract static class MangledComprehensionName { + + /** Mangled name for iter_var */ + public abstract String iterVarName(); + + /** Mangled name for comprehension result */ + public abstract String resultName(); + + private static MangledComprehensionName of(String iterVarName, String resultName) { + return new AutoValue_MutableAst_MangledComprehensionName(iterVarName, resultName); } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 4933b90d..8cf71e75 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -85,6 +85,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer { new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c"; + private static final String MANGLED_COMPREHENSION_RESULT_PREFIX = "@x"; private static final String CEL_BLOCK_FUNCTION = "cel.@block"; private static final String BLOCK_INDEX_PREFIX = "@index"; private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = @@ -127,7 +128,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( CelType resultType = navigableAst.getAst().getResultType(); MangledComprehensionAst mangledComprehensionAst = mutableAst.mangleComprehensionIdentifierNames( - navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); + navigableAst.getAst(), + MANGLED_COMPREHENSION_IDENTIFIER_PREFIX, + MANGLED_COMPREHENSION_RESULT_PREFIX); CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast(); CelSource sourceToModify = astToModify.getSource(); @@ -191,10 +194,12 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( // Add all mangled comprehension identifiers to the environment, so that the subexpressions can // retain context to them. mangledComprehensionAst - .mangledComprehensionIdents() + .mangledComprehensionMap() .forEach( - (identName, type) -> - celBuilder.addVarDeclarations(CelVarDecl.newVarDeclaration(identName, type))); + (name, type) -> + celBuilder.addVarDeclarations( + CelVarDecl.newVarDeclaration(name.iterVarName(), type.iterVarType()), + CelVarDecl.newVarDeclaration(name.resultName(), type.resultType()))); // Type-check all sub-expressions then add them as block identifiers to the CEL environment addBlockIdentsToEnv(celBuilder, subexpressions); @@ -266,7 +271,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) CelAbstractSyntaxTree astToModify = mutableAst .mangleComprehensionIdentifierNames( - navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX) + navigableAst.getAst(), + MANGLED_COMPREHENSION_IDENTIFIER_PREFIX, + MANGLED_COMPREHENSION_RESULT_PREFIX) .ast(); CelSource sourceToModify = astToModify.getSource(); @@ -432,7 +439,7 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) { // result, loopStep or iterRange. While result is not human authored, it needs to be // included to extract subexpressions that are already in cel.bind macro. CelNavigableExpr.fromExpr(parent.expr().comprehension().result()).descendants(), - CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).descendants(), + CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).allNodes(), CelNavigableExpr.fromExpr(parent.expr().comprehension().iterRange()).allNodes()) .filter( node -> diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index 557302e4..da8ecfb7 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -682,7 +682,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); CelAbstractSyntaxTree mangledAst = - MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c", "@x").ast(); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -695,7 +695,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { + " }\n" + " }\n" + " }\n" - + " accu_var: __result__\n" + + " accu_var: @x0:0\n" + " accu_init: {\n" + " CONSTANT [6] { value: false }\n" + " }\n" @@ -707,7 +707,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { + " function: !_\n" + " args: {\n" + " IDENT [7] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " }\n" + " }\n" @@ -719,7 +719,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { + " function: _||_\n" + " args: {\n" + " IDENT [10] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " IDENT [5] {\n" + " name: @c0:0\n" @@ -729,7 +729,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { + " }\n" + " result: {\n" + " IDENT [12] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " }\n" + "}"); @@ -743,7 +743,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst(); CelAbstractSyntaxTree mangledAst = - MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c", "@x").ast(); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -758,7 +758,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " }\n" + " }\n" + " }\n" - + " accu_var: __result__\n" + + " accu_var: @x0:0\n" + " accu_init: {\n" + " CONSTANT [20] { value: false }\n" + " }\n" @@ -770,7 +770,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " function: !_\n" + " args: {\n" + " IDENT [21] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " }\n" + " }\n" @@ -782,7 +782,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " function: _||_\n" + " args: {\n" + " IDENT [24] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " COMPREHENSION [19] {\n" + " iter_var: @c1:0\n" @@ -795,7 +795,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " }\n" + " }\n" + " }\n" - + " accu_var: __result__\n" + + " accu_var: @x1:0\n" + " accu_init: {\n" + " CONSTANT [12] { value: false }\n" + " }\n" @@ -807,7 +807,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " function: !_\n" + " args: {\n" + " IDENT [13] {\n" - + " name: __result__\n" + + " name: @x1:0\n" + " }\n" + " }\n" + " }\n" @@ -819,7 +819,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " function: _||_\n" + " args: {\n" + " IDENT [16] {\n" - + " name: __result__\n" + + " name: @x1:0\n" + " }\n" + " CALL [10] {\n" + " function: _==_\n" @@ -835,7 +835,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " }\n" + " result: {\n" + " IDENT [18] {\n" - + " name: __result__\n" + + " name: @x1:0\n" + " }\n" + " }\n" + " }\n" @@ -844,11 +844,10 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw + " }\n" + " result: {\n" + " IDENT [26] {\n" - + " name: __result__\n" + + " name: @x0:0\n" + " }\n" + " }\n" + "}"); - assertThat(CEL_UNPARSER.unparse(mangledAst)) .isEqualTo("[x].exists(@c0:0, [@c0:0].exists(@c1:0, @c1:0 == 1))"); assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval(ImmutableMap.of("x", 1))) @@ -861,7 +860,7 @@ public void mangleComprehensionVariable_hasMacro_noOp() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst(); CelAbstractSyntaxTree mangledAst = - MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast(); + MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c", "@x").ast(); assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)"); assertThat(