Skip to content

Commit

Permalink
Mangle identifier name for comprehension result
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609094679
  • Loading branch information
l46kok authored and copybara-github committed Feb 21, 2024
1 parent ad2c6b6 commit 90671c0
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 66 deletions.
2 changes: 2 additions & 0 deletions common/src/main/java/dev/cel/common/ast/CelExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,8 @@ public abstract static class CelComprehension {
/** Builder for Comprehension. */
@AutoValue.Builder
public abstract static class Builder {
public abstract String accuVar();

public abstract CelExpr iterRange();

public abstract CelExpr accuInit();
Expand Down
183 changes: 139 additions & 44 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelComprehension;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.ast.CelExprFactory;
Expand All @@ -45,6 +46,7 @@
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/** MutableAst contains logic for mutating a {@link CelAbstractSyntaxTree}. */
Expand Down Expand Up @@ -208,20 +210,27 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
* </ul>
*
* @param ast AST to mutate
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
* @param newIterVarPrefix Prefix to use for new iteration variable identifier name. For example,
* providing @c will produce @c0:0, @c0:1, @c1:0, @c2:0... as new names.
* @param newResultPrefix Prefix to use for new comprehensin result identifier names.
*/
public MangledComprehensionAst mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
CelAbstractSyntaxTree ast, String newIterVarPrefix, String newResultPrefix) {
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
LinkedHashMap<CelNavigableExpr, CelType> comprehensionsToMangle =
Predicate<CelNavigableExpr> comprehensionIdentifierPredicate = x -> true;
comprehensionIdentifierPredicate =
comprehensionIdentifierPredicate
.and(node -> node.getKind().equals(Kind.COMPREHENSION))
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix))
.and(node -> !node.expr().comprehension().accuVar().startsWith(newResultPrefix));

LinkedHashMap<CelNavigableExpr, MangledComprehensionType> comprehensionsToMangle =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.filter(comprehensionIdentifierPredicate)
.filter(
node -> {
// Ensure the iter_var is actually referenced in the loop_step. If it's not, we
Expand All @@ -236,9 +245,10 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
Collectors.toMap(
k -> k,
v -> {
String iterVar = v.expr().comprehension().iterVar();
CelComprehension comprehension = v.expr().comprehension();
String iterVar = comprehension.iterVar();
long iterVarId =
CelNavigableExpr.fromExpr(v.expr().comprehension().loopStep())
CelNavigableExpr.fromExpr(comprehension.loopStep())
.allNodes()
.filter(
loopStepNode ->
Expand All @@ -252,11 +262,22 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
+ v.id());
});

return ast.getType(iterVarId)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for: " + iterVarId));
CelType iterVarType =
ast.getType(iterVarId)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for iteration variable: "
+ iterVarId));
CelType resultType =
ast.getType(comprehension.result().id())
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for result: "
+ comprehension.result().id()));

return MangledComprehensionType.of(iterVarType, resultType);
},
(x, y) -> {
throw new IllegalStateException("Unexpected CelNavigableExpr collision");
Expand All @@ -265,53 +286,62 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
int iterCount = 0;

// The map that we'll eventually return to the caller.
HashMap<String, CelType> mangledIdentNamesToType = new HashMap<>();
HashMap<MangledComprehensionName, MangledComprehensionType> mangledIdentNamesToType =
new HashMap<>();
// Intermediary table used for the purposes of generating a unique mangled variable name.
Table<Integer, CelType, String> comprehensionLevelToType = HashBasedTable.create();
for (Entry<CelNavigableExpr, CelType> comprehensionEntry : comprehensionsToMangle.entrySet()) {
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType =
HashBasedTable.create();
for (Entry<CelNavigableExpr, MangledComprehensionType> comprehensionEntry :
comprehensionsToMangle.entrySet()) {
iterCount++;
// Refetch the comprehension node as mutating the AST could have renumbered its IDs.
CelNavigableExpr comprehensionNode =
newNavigableAst
.getRoot()
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.filter(comprehensionIdentifierPredicate)
.findAny()
.orElseThrow(
() -> new NoSuchElementException("Failed to refetch mutated comprehension"));
CelType comprehensionEntryType = comprehensionEntry.getValue();
MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue();

CelExpr.Builder comprehensionExpr = comprehensionNode.expr().toBuilder();
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
String mangledVarName;
MangledComprehensionName mangledComprehensionName;
if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) {
mangledVarName =
mangledComprehensionName =
comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType);
} else {
// First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique
// mangled variable name for this.
int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size();
mangledVarName = newIdentPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
String mangledIterVarName =
newIterVarPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
String mangledResultName =
newResultPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
mangledComprehensionName =
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
comprehensionLevelToType.put(
comprehensionNestingLevel, comprehensionEntryType, mangledVarName);
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
}
mangledIdentNamesToType.put(mangledVarName, comprehensionEntryType);
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);

String iterVar = comprehensionExpr.comprehension().iterVar();
String accuVar = comprehensionExpr.comprehension().accuVar();
CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
newNavigableAst.getAst().getExpr().toBuilder(),
comprehensionExpr,
iterVar,
mangledVarName);
accuVar,
mangledComprehensionName);
// Repeat the mangling process for the macro source.
CelSource newSource =
mangleIdentsInMacroSource(
newNavigableAst.getAst(),
mutatedComprehensionExpr,
iterVar,
mangledVarName,
mangledComprehensionName,
comprehensionExpr.id());

newNavigableAst =
Expand Down Expand Up @@ -381,14 +411,44 @@ private CelExpr.Builder mangleIdentsInComprehensionExpr(
CelExpr.Builder root,
CelExpr.Builder comprehensionExpr,
String originalIterVar,
String mangledVarName) {
String originalAccuVar,
MangledComprehensionName mangledComprehensionName) {
CelExpr.Builder modifiedLoopStep =
replaceIdentName(
comprehensionExpr.comprehension().loopStep().toBuilder(),
originalIterVar,
mangledComprehensionName.iterVarName());
comprehensionExpr.setComprehension(
comprehensionExpr.comprehension().toBuilder()
.setLoopStep(modifiedLoopStep.build())
.build());
comprehensionExpr =
replaceIdentName(comprehensionExpr, originalAccuVar, mangledComprehensionName.resultName());

CelComprehension.Builder newComprehension =
comprehensionExpr.comprehension().toBuilder()
.setIterVar(mangledComprehensionName.iterVarName());
// Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
if (newComprehension.accuVar().equals(originalAccuVar)) {
newComprehension.setAccuVar(mangledComprehensionName.resultName());
}

return mutateExpr(
NO_OP_ID_GENERATOR,
root,
comprehensionExpr.setComprehension(newComprehension.build()),
comprehensionExpr.id());
}

private CelExpr.Builder replaceIdentName(
CelExpr.Builder comprehensionExpr, String originalIdentName, String newIdentName) {
int iterCount;
for (iterCount = 0; iterCount < iterationLimit; iterCount++) {
Optional<CelExpr> identToMangle =
CelNavigableExpr.fromExpr(comprehensionExpr.comprehension().loopStep())
CelNavigableExpr.fromExpr(comprehensionExpr.build())
.descendants()
.map(CelNavigableExpr::expr)
.filter(node -> node.identOrDefault().name().equals(originalIterVar))
.filter(node -> node.identOrDefault().name().equals(originalIdentName))
.findAny();
if (!identToMangle.isPresent()) {
break;
Expand All @@ -398,27 +458,22 @@ private CelExpr.Builder mangleIdentsInComprehensionExpr(
mutateExpr(
NO_OP_ID_GENERATOR,
comprehensionExpr,
CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()),
CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(newIdentName).build()),
identToMangle.get().id());
}

if (iterCount >= iterationLimit) {
throw new IllegalStateException("Max iteration count reached.");
}

return mutateExpr(
NO_OP_ID_GENERATOR,
root,
comprehensionExpr.setComprehension(
comprehensionExpr.comprehension().toBuilder().setIterVar(mangledVarName).build()),
comprehensionExpr.id());
return comprehensionExpr;
}

private CelSource mangleIdentsInMacroSource(
CelAbstractSyntaxTree ast,
CelExpr.Builder mutatedComprehensionExpr,
String originalIterVar,
String mangledVarName,
MangledComprehensionName mangledComprehensionName,
long originalComprehensionId) {
if (!ast.getSource().getMacroCalls().containsKey(originalComprehensionId)) {
return ast.getSource();
Expand Down Expand Up @@ -446,7 +501,9 @@ private CelSource mangleIdentsInMacroSource(
mutateExpr(
NO_OP_ID_GENERATOR,
macroExpr,
CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()),
CelExpr.newBuilder()
.setIdent(
CelIdent.newBuilder().setName(mangledComprehensionName.iterVarName()).build()),
identToMangle.id());

newSource.addMacroCalls(originalComprehensionId, macroExpr.build());
Expand Down Expand Up @@ -652,8 +709,8 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension
}

/**
* Intermediate value class to store the mangled identifiers for iteration variable in the
* comprehension.
* Intermediate value class to store the mangled identifiers for iteration variable and the
* comprehension result.
*/
@AutoValue
public abstract static class MangledComprehensionAst {
Expand All @@ -662,11 +719,49 @@ public abstract static class MangledComprehensionAst {
public abstract CelAbstractSyntaxTree ast();

/** Map containing the mangled identifier names to their types. */
public abstract ImmutableMap<String, CelType> mangledComprehensionIdents();
public abstract ImmutableMap<MangledComprehensionName, MangledComprehensionType>
mangledComprehensionMap();

private static MangledComprehensionAst of(
CelAbstractSyntaxTree ast, ImmutableMap<String, CelType> mangledComprehensionIdents) {
return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents);
CelAbstractSyntaxTree ast,
ImmutableMap<MangledComprehensionName, MangledComprehensionType> mangledComprehensionMap) {
return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionMap);
}
}

/**
* Intermediate value class to store the types for iter_var and comprehension result of which its
* identifier names are being mangled.
*/
@AutoValue
public abstract static class MangledComprehensionType {

/** Type of iter_var */
public abstract CelType iterVarType();

/** Type of comprehension result */
public abstract CelType resultType();

private static MangledComprehensionType of(CelType iterVarType, CelType resultType) {
return new AutoValue_MutableAst_MangledComprehensionType(iterVarType, resultType);
}
}

/**
* Intermediate value class to store the mangled names for iteration variable and the
* comprehension result.
*/
@AutoValue
public abstract static class MangledComprehensionName {

/** Mangled name for iter_var */
public abstract String iterVarName();

/** Mangled name for comprehension result */
public abstract String resultName();

private static MangledComprehensionName of(String iterVarName, String resultName) {
return new AutoValue_MutableAst_MangledComprehensionName(iterVarName, resultName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
private static final String MANGLED_COMPREHENSION_RESULT_PREFIX = "@x";
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
private static final String BLOCK_INDEX_PREFIX = "@index";
private static final ImmutableSet<String> CSE_ALLOWED_FUNCTIONS =
Expand Down Expand Up @@ -127,7 +128,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
CelType resultType = navigableAst.getAst().getResultType();
MangledComprehensionAst mangledComprehensionAst =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
navigableAst.getAst(),
MANGLED_COMPREHENSION_IDENTIFIER_PREFIX,
MANGLED_COMPREHENSION_RESULT_PREFIX);
CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast();
CelSource sourceToModify = astToModify.getSource();

Expand Down Expand Up @@ -191,10 +194,12 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
// Add all mangled comprehension identifiers to the environment, so that the subexpressions can
// retain context to them.
mangledComprehensionAst
.mangledComprehensionIdents()
.mangledComprehensionMap()
.forEach(
(identName, type) ->
celBuilder.addVarDeclarations(CelVarDecl.newVarDeclaration(identName, type)));
(name, type) ->
celBuilder.addVarDeclarations(
CelVarDecl.newVarDeclaration(name.iterVarName(), type.iterVarType()),
CelVarDecl.newVarDeclaration(name.resultName(), type.resultType())));
// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

Expand Down Expand Up @@ -266,7 +271,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst)
CelAbstractSyntaxTree astToModify =
mutableAst
.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX)
navigableAst.getAst(),
MANGLED_COMPREHENSION_IDENTIFIER_PREFIX,
MANGLED_COMPREHENSION_RESULT_PREFIX)
.ast();
CelSource sourceToModify = astToModify.getSource();

Expand Down Expand Up @@ -432,7 +439,7 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) {
// result, loopStep or iterRange. While result is not human authored, it needs to be
// included to extract subexpressions that are already in cel.bind macro.
CelNavigableExpr.fromExpr(parent.expr().comprehension().result()).descendants(),
CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).descendants(),
CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).allNodes(),
CelNavigableExpr.fromExpr(parent.expr().comprehension().iterRange()).allNodes())
.filter(
node ->
Expand Down
Loading

0 comments on commit 90671c0

Please sign in to comment.