Skip to content

Commit

Permalink
Perform CSE on comprehension loop step by mangling identifier names
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599877299
  • Loading branch information
l46kok authored and copybara-github committed Jan 19, 2024
1 parent 790e8cf commit 5a7cbab
Show file tree
Hide file tree
Showing 5 changed files with 392 additions and 18 deletions.
30 changes: 30 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ default CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro(
ast, varName, varInit, resultExpr, exprIdToReplace);
}

/**
* Replaces all comprehension identifier names with a unique name based on the given prefix.
*
* <p>The purpose of this is to avoid errors that can be caused by shadowed variables while
* augmenting an AST. As an example: {@code [2, 3].exists(x, x - 1 > 3) || x - 1 > 3}. Note that
* the scoping of `x - 1` is different between th two LOGICAL_OR branches. Iteration variable `x`
* in `exists` will be mangled to {@code [2, 3].exists(@c0, @c0 - 1 > 3) || x - 1 > 3} to avoid
* erroneously extracting x - 1 as common subexpression.
*
* <p>The expression IDs are not modified when the identifier names are changed.
*
* <p>Iteration variables in comprehensions are numbered based on their comprehension nesting
* levels. Examples:
*
* <ul>
* <li>{@code [true].exists(i, i) && [true].exists(j, j)} -> {@code [true].exists(@c0, @c0) &&
* [true].exists(@c0, @c0)} // Note that i,j gets replaced to the same @c0 in this example
* <li>{@code [true].exists(i, i && [true].exists(j, j))} -> {@code [true].exists(@c0, @c0 &&
* [true].exists(@c1, @c1))}
* </ul>
*
* @param ast AST to mutate
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0, @c1, @c2... as new names.
*/
default CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
return MutableAst.mangleComprehensionIdentifierNames(ast, newIdentPrefix);
}

/** Sets all expr IDs in the expression tree to 0. */
default CelExpr clearExprIds(CelExpr celExpr) {
return MutableAst.clearExprIds(celExpr);
Expand Down
145 changes: 145 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,26 @@
import dev.cel.common.ast.CelExpr.CelCreateList;
import dev.cel.common.ast.CelExpr.CelCreateMap;
import dev.cel.common.ast.CelExpr.CelCreateStruct;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.CelSelect;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.ast.CelExprFactory;
import dev.cel.common.ast.CelExprIdGeneratorFactory;
import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator;
import dev.cel.common.ast.CelExprIdGeneratorFactory.StableIdGenerator;
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder;
import java.util.Map.Entry;
import java.util.NoSuchElementException;
import java.util.Optional;

/** MutableAst contains logic for mutating a {@link CelExpr}. */
@Internal
final class MutableAst {
private static final int MAX_ITERATION_COUNT = 1000;
private static final ExprIdGenerator NO_OP_ID_GENERATOR = id -> id;

private final CelExpr.Builder newExpr;
private final ExprIdGenerator celExprIdGenerator;
private int iterationCount;
Expand Down Expand Up @@ -160,6 +166,132 @@ static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
return CelAbstractSyntaxTree.newParsedAst(root.build(), newSource);
}

static CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
int iterCount;
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) {
Optional<CelNavigableExpr> maybeComprehensionExpr =
newNavigableAst
.getRoot()
// This is important - mangling needs to happen bottom-up to avoid stepping over
// shadowed variables that are not part of the comprehension being mangled.
.allNodes(TraversalOrder.POST_ORDER)
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix))
.findAny();
if (!maybeComprehensionExpr.isPresent()) {
break;
}

CelExpr.Builder comprehensionExpr = maybeComprehensionExpr.get().expr().toBuilder();
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(maybeComprehensionExpr.get());
String mangledVarName = newIdentPrefix + comprehensionNestingLevel;

CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
newNavigableAst.getAst().getExpr().toBuilder(),
comprehensionExpr,
iterVar,
mangledVarName);
// Repeat the mangling process for the macro source.
CelSource newSource =
mangleIdentsInMacroSource(
newNavigableAst.getAst(),
mutatedComprehensionExpr,
iterVar,
mangledVarName,
comprehensionExpr.id());

newNavigableAst =
CelNavigableAst.fromAst(
CelAbstractSyntaxTree.newParsedAst(mutatedComprehensionExpr.build(), newSource));
}

if (iterCount >= MAX_ITERATION_COUNT) {
throw new IllegalStateException("Max iteration count reached.");
}

return newNavigableAst.getAst();
}

private static CelExpr.Builder mangleIdentsInComprehensionExpr(
CelExpr.Builder root,
CelExpr.Builder comprehensionExpr,
String originalIterVar,
String mangledVarName) {
int iterCount;
for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) {
Optional<CelExpr> identToMangle =
CelNavigableExpr.fromExpr(comprehensionExpr.comprehension().loopStep())
.descendants()
.map(CelNavigableExpr::expr)
.filter(node -> node.identOrDefault().name().equals(originalIterVar))
.findAny();
if (!identToMangle.isPresent()) {
break;
}

comprehensionExpr =
replaceSubtreeImpl(
NO_OP_ID_GENERATOR,
comprehensionExpr,
CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()),
identToMangle.get().id());
}

if (iterCount >= MAX_ITERATION_COUNT) {
throw new IllegalStateException("Max iteration count reached.");
}

return replaceSubtreeImpl(
NO_OP_ID_GENERATOR,
root,
comprehensionExpr.setComprehension(
comprehensionExpr.comprehension().toBuilder().setIterVar(mangledVarName).build()),
comprehensionExpr.id());
}

private static CelSource mangleIdentsInMacroSource(
CelAbstractSyntaxTree ast,
CelExpr.Builder mutatedComprehensionExpr,
String originalIterVar,
String mangledVarName,
long originalComprehensionId) {
if (!ast.getSource().getMacroCalls().containsKey(originalComprehensionId)) {
return ast.getSource();
}

// First, normalize the macro source.
// ex: [x].exists(x, [x].exists(x, x == 1)) -> [x].exists(x, [@c1].exists(x, @c0 == 1)).
CelSource.Builder newSource =
normalizeMacroSource(ast.getSource(), -1, mutatedComprehensionExpr, (id) -> id).toBuilder();

// Note that in the above example, the iteration variable is not replaced after normalization.
// This is because populating a macro call map upon parse generates a new unique identifier
// that does not exist in the main AST. Thus, we need to manually replace the identifier.
CelExpr.Builder macroExpr = newSource.getMacroCalls().get(originalComprehensionId).toBuilder();
// By convention, the iteration variable is always the first argument of the
// macro call expression.
CelExpr identToMangle = macroExpr.call().args().get(0);
if (!identToMangle.identOrDefault().name().equals(originalIterVar)) {
throw new IllegalStateException(
String.format(
"Expected %s for iteration variable but got %s instead.",
identToMangle.identOrDefault().name(), originalIterVar));
}
macroExpr =
replaceSubtreeImpl(
NO_OP_ID_GENERATOR,
macroExpr,
CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()),
identToMangle.id());

newSource.addMacroCalls(originalComprehensionId, macroExpr.build());
return newSource.build();
}

private static BindMacro newBindMacro(
String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) {
// Renumber incoming expression IDs in the init and result expression to avoid collision with
Expand Down Expand Up @@ -344,6 +476,19 @@ private static long getMaxId(CelExpr newExpr) {
.orElseThrow(NoSuchElementException::new);
}

private static int countComprehensionNestingLevel(CelNavigableExpr comprehensionExpr) {
int nestedLevel = 0;
Optional<CelNavigableExpr> maybeParent = comprehensionExpr.parent();
while (maybeParent.isPresent()) {
if (maybeParent.get().getKind().equals(Kind.COMPREHENSION)) {
nestedLevel++;
}

maybeParent = maybeParent.get().parent();
}
return nestedLevel;
}

private CelExpr.Builder visit(CelExpr.Builder expr) {
if (++iterationCount > MAX_ITERATION_COUNT) {
throw new IllegalStateException("Max iteration count reached.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
private static final SubexpressionOptimizer INSTANCE =
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
private static final ImmutableSet<String> CSE_ALLOWED_FUNCTIONS =
Streams.concat(
stream(Operator.values()).map(Operator::getFunction),
Expand All @@ -88,8 +89,11 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c

@Override
public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
CelAbstractSyntaxTree astToModify = navigableAst.getAst();
CelAbstractSyntaxTree astToModify =
mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelSource sourceToModify = astToModify.getSource();

int bindIdentifierIndex = 0;
int iterCount;
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
Expand Down Expand Up @@ -247,9 +251,10 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) {
if (parent.getKind().equals(Kind.COMPREHENSION)) {
return Streams.concat(
// If the expression is within a comprehension, it is eligible for CSE iff is in
// result or iterRange. While result is not human authored, it needs to be included
// to extract subexpressions that are already in cel.bind macro.
// result, loopStep or iterRange. While result is not human authored, it needs to be
// included to extract subexpressions that are already in cel.bind macro.
CelNavigableExpr.fromExpr(parent.expr().comprehension().result()).descendants(),
CelNavigableExpr.fromExpr(parent.expr().comprehension().loopStep()).descendants(),
CelNavigableExpr.fromExpr(parent.expr().comprehension().iterRange()).allNodes())
.filter(
node ->
Expand Down
Loading

0 comments on commit 5a7cbab

Please sign in to comment.