Skip to content

Commit

Permalink
Implement Optimizer for Common Subexpression Elimination
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599696579
  • Loading branch information
l46kok authored and copybara-github committed Jan 19, 2024
1 parent 8dffc6c commit 790e8cf
Show file tree
Hide file tree
Showing 13 changed files with 1,111 additions and 25 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ bazel-cel-java
bazel-out
bazel-testlogs

MODULE.bazel*

# IntelliJ IDEA
.idea
*.iml
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ public boolean hasId(long id) {
return idSet.containsKey(id);
}

/** Advances the ID counter by one and returns the new value as the next available ID. */
public long nextExprId() {
  exprId += 1;
  return exprId;
}

/**
* Generate the next available ID while memoizing the existing ID.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public enum TraversalOrder {

public abstract CelExpr expr();

/** Returns the ID of the expression held by this node; equivalent to {@code expr().id()}. */
public long id() {
  CelExpr heldExpr = expr();
  return heldExpr.id();
}

public abstract Optional<CelNavigableExpr> parent();

/** Represents the count of transitive parents. Depth of an AST's root is 0. */
Expand Down
6 changes: 6 additions & 0 deletions optimizer/optimizers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ java_library(
name = "constant_folding",
exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:constant_folding"],
)

# Re-exports the common subexpression elimination optimizer target from the src tree.
# NOTE(review): visibility is already public while the TODO below still says "Expose when
# ready" -- confirm which of the two is intended.
java_library(
name = "common_subexpression_elimination",
visibility = ["//visibility:public"], # TODO: Expose when ready
exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:common_subexpression_elimination"],
)
32 changes: 32 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,36 @@ default CelAbstractSyntaxTree replaceSubtree(
CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) {
return MutableAst.replaceSubtree(ast, newExpr, exprIdToReplace);
}

/**
* Generates a new bind macro using the provided initialization and result expression, then
* replaces the subtree using the new bind expr at the designated expr ID.
*
* <p>The bind call takes the format of: {@code cel.bind(varName, varInit, resultExpr)}
*
* <p>This is a thin wrapper that delegates to {@code MutableAst.replaceSubtreeWithNewBindMacro}.
*
* @param ast Original ast to mutate.
* @param varName New variable name for the bind macro call.
* @param varInit Initialization expression to bind to the local variable.
* @param resultExpr Result expression of the bind macro.
* @param exprIdToReplace Expression ID of the subtree that is getting replaced.
* @return The AST returned by {@code MutableAst.replaceSubtreeWithNewBindMacro}, in which the
*     new bind expr stands in place of the replaced subtree.
*/
default CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro(
CelAbstractSyntaxTree ast,
String varName,
CelExpr varInit,
CelExpr resultExpr,
long exprIdToReplace) {
return MutableAst.replaceSubtreeWithNewBindMacro(
ast, varName, varInit, resultExpr, exprIdToReplace);
}

/**
 * Sets all expr IDs in the given expression tree to 0 and returns the resulting expression.
 *
 * @param celExpr Root of the expression tree whose IDs are cleared.
 * @return The expression produced by {@code MutableAst.clearExprIds}.
 */
default CelExpr clearExprIds(CelExpr celExpr) {
  CelExpr exprWithoutIds = MutableAst.clearExprIds(celExpr);
  return exprWithoutIds;
}

/**
 * Renumbers all the expr IDs in the given AST in a consecutive manner starting from 1.
 *
 * @param ast AST whose expr IDs are renumbered.
 * @return The AST produced by {@code MutableAst.renumberIdsConsecutively}.
 */
default CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) {
  CelAbstractSyntaxTree renumberedAst = MutableAst.renumberIdsConsecutively(ast);
  return renumberedAst;
}
}
59 changes: 44 additions & 15 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dev.cel.common.ast.CelExpr.CelCreateMap;
import dev.cel.common.ast.CelExpr.CelCreateStruct;
import dev.cel.common.ast.CelExpr.CelSelect;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
import dev.cel.common.ast.CelExprFactory;
import dev.cel.common.ast.CelExprIdGeneratorFactory;
import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator;
Expand All @@ -42,7 +43,7 @@
/** MutableAst contains logic for mutating a {@link CelExpr}. */
@Internal
final class MutableAst {
private static final int MAX_ITERATION_COUNT = 500;
private static final int MAX_ITERATION_COUNT = 1000;
private final CelExpr.Builder newExpr;
private final ExprIdGenerator celExprIdGenerator;
private int iterationCount;
Expand Down Expand Up @@ -91,7 +92,8 @@ static CelAbstractSyntaxTree replaceSubtree(
// Mutate the AST root with the new subtree. All the existing expr IDs are renumbered in the
// process, but its original IDs are memoized so that we can normalize the expr IDs
// in the macro source map.
StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0);
StableIdGenerator stableIdGenerator =
CelExprIdGeneratorFactory.newStableIdGenerator(getMaxId(newAst));
CelExpr.Builder mutatedRoot =
replaceSubtreeImpl(
stableIdGenerator::renumberId,
Expand Down Expand Up @@ -147,13 +149,26 @@ static CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro(
ast, CelAbstractSyntaxTree.newParsedAst(bindMacro.bindExpr(), celSource), exprIdToReplace);
}

/**
 * Renumbers the expr IDs of the given AST using a fresh stable ID generator seeded with 0, then
 * normalizes the macro source against the renumbered tree using the same generator.
 *
 * @param ast AST supplying the root expression and the source to renumber.
 * @return A new parsed AST holding the renumbered root expression and the normalized source.
 */
static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) {
  StableIdGenerator idGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0);
  CelExpr.Builder renumberedRoot =
      renumberExprIds(idGenerator::renumberId, ast.getExpr().toBuilder());

  // NOTE(review): Integer.MIN_VALUE presumably serves as a "no expr is being replaced" sentinel
  // here -- confirm against normalizeMacroSource.
  CelSource normalizedSource =
      normalizeMacroSource(
          ast.getSource(), Integer.MIN_VALUE, renumberedRoot, idGenerator::renumberId);

  return CelAbstractSyntaxTree.newParsedAst(renumberedRoot.build(), normalizedSource);
}

private static BindMacro newBindMacro(
String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) {
// Clear incoming expression IDs in the initialization expression to avoid collision with the
// main AST.
varInit = clearExprIds(varInit);
// Renumber incoming expression IDs in the init and result expression to avoid collision with
// the main AST. Existing IDs are memoized for a macro source sanitization pass at the end
// (e.g: inserting a bind macro to an existing macro expr)
varInit = renumberExprIds(stableIdGenerator::nextExprId, varInit.toBuilder()).build();
resultExpr = renumberExprIds(stableIdGenerator::nextExprId, resultExpr.toBuilder()).build();
CelExprFactory exprFactory =
CelExprFactory.newInstance((unused) -> stableIdGenerator.nextExprId(-1));
CelExprFactory.newInstance((unused) -> stableIdGenerator.nextExprId());
CelExpr bindMacroExpr =
exprFactory.fold(
"#unused",
Expand All @@ -164,17 +179,12 @@ private static BindMacro newBindMacro(
exprFactory.newIdentifier(varName),
resultExpr);

// Update the IDs in the new expression tree first. This ensures that no ID collision
// occurs while attempting to replace the subtree later, potentially leading to an infinite loop
bindMacroExpr =
renumberExprIds(stableIdGenerator::nextExprId, bindMacroExpr.toBuilder()).build();

CelExpr bindMacroCallExpr =
exprFactory
.newReceiverCall(
"bind",
CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), "cel"),
CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), varName),
CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(), "cel"),
CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(), varName),
bindMacroExpr.comprehension().accuInit(),
bindMacroExpr.comprehension().result())
.toBuilder()
Expand Down Expand Up @@ -270,6 +280,7 @@ private static CelSource normalizeMacroSource(
if (!allExprs.containsKey(callChild.id())) {
continue;
}

CelExpr mutatedExpr = allExprs.get(callChild.id());
if (!callChild.equals(mutatedExpr)) {
newCall =
Expand All @@ -279,6 +290,25 @@ private static CelSource normalizeMacroSource(
sourceBuilder.addMacroCalls(callId, newCall.build());
}

// Replace comprehension nodes with a NOT_SET reference to reduce AST size.
for (Entry<Long, CelExpr> macroCall : sourceBuilder.getMacroCalls().entrySet()) {
CelExpr macroCallExpr = macroCall.getValue();
CelNavigableExpr.fromExpr(macroCallExpr)
.allNodes()
.filter(node -> node.getKind().equals(Kind.COMPREHENSION))
.map(CelNavigableExpr::expr)
.forEach(
node -> {
CelExpr.Builder mutatedNode =
replaceSubtreeImpl(
(id) -> id,
macroCallExpr.toBuilder(),
CelExpr.ofNotSet(node.id()).toBuilder(),
node.id());
macroCall.setValue(mutatedNode.build());
});
}

return sourceBuilder.build();
}

Expand Down Expand Up @@ -309,7 +339,7 @@ private static long getMaxId(CelAbstractSyntaxTree ast) {
private static long getMaxId(CelExpr newExpr) {
return CelNavigableExpr.fromExpr(newExpr)
.allNodes()
.mapToLong(node -> node.expr().id())
.mapToLong(CelNavigableExpr::id)
.max()
.orElseThrow(NoSuchElementException::new);
}
Expand Down Expand Up @@ -419,7 +449,6 @@ private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder com
*/
@AutoValue
abstract static class BindMacro {

/** Comprehension expr for the generated cel.bind macro. */
abstract CelExpr bindExpr();

Expand Down
21 changes: 21 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,24 @@ java_library(
"@maven//:com_google_guava_guava",
],
)

# Common subexpression elimination optimizer, built from SubexpressionOptimizer.java.
java_library(
name = "common_subexpression_elimination",
srcs = [
"SubexpressionOptimizer.java",
],
tags = [
],
deps = [
"//:auto_value",
"//bundle:cel",
"//checker:checker_legacy_environment",
"//common",
"//common/ast",
"//common/navigation",
"//optimizer:ast_optimizer",
"//parser:operator",
"@maven//:com_google_guava_guava",
"@maven//:org_jspecify_jspecify",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,8 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)

// If the output is a list, map, or struct which contains optional entries, then prune it
// to make sure that the optionals, if resolved, do not surface in the output literal.
navigableAst = CelNavigableAst.fromAst(pruneOptionalElements(navigableAst));

return navigableAst.getAst();
CelAbstractSyntaxTree newAst = pruneOptionalElements(navigableAst);
return renumberIdsConsecutively(newAst);
}

private static boolean canFold(CelNavigableExpr navigableExpr) {
Expand Down
Loading

0 comments on commit 790e8cf

Please sign in to comment.