Skip to content

Commit

Permalink
Allow setting nesting limit for extractable subexpressions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609114886
  • Loading branch information
l46kok authored and copybara-github committed Feb 21, 2024
1 parent 1e12305 commit 5ebf44e
Show file tree
Hide file tree
Showing 4 changed files with 627 additions and 172 deletions.
106 changes: 65 additions & 41 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelComprehension;
import dev.cel.common.ast.CelExpr.CelIdent;
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
Expand All @@ -41,6 +42,7 @@
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder;
import dev.cel.common.types.CelType;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
Expand Down Expand Up @@ -132,6 +134,26 @@ public CelAbstractSyntaxTree replaceSubtree(
exprIdToReplace);
}

/** Wraps the given AST and its subexpressions with a new cel.@block call. */
public CelAbstractSyntaxTree wrapAstWithNewCelBlock(
String celBlockFunction, CelAbstractSyntaxTree ast, Collection<CelExpr> subexpressions) {
long maxId = getMaxId(ast);
CelExpr blockExpr =
CelExpr.newBuilder()
.setId(++maxId)
.setCall(
CelCall.newBuilder()
.setFunction(celBlockFunction)
.addArgs(
CelExpr.ofCreateListExpr(
++maxId, ImmutableList.copyOf(subexpressions), ImmutableList.of()),
ast.getExpr())
.build())
.build();

return CelAbstractSyntaxTree.newParsedAst(blockExpr, ast.getSource());
}

/**
* Generates a new bind macro using the provided initialization and result expression, then
* replaces the subtree using the new bind expr at the designated expr ID.
Expand Down Expand Up @@ -233,49 +255,47 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
.filter(comprehensionIdentifierPredicate)
.filter(
node -> {
// Ensure the iter_var is actually referenced in the loop_step. If it's not, we
// Ensure the iter_var or the comprehension result is actually referenced in the
// loop_step. If it's not, we
// can skip mangling.
String iterVar = node.expr().comprehension().iterVar();
String result = node.expr().comprehension().result().ident().name();
return CelNavigableExpr.fromExpr(node.expr().comprehension().loopStep())
.allNodes()
.filter(subNode -> subNode.getKind().equals(Kind.IDENT))
.map(subNode -> subNode.expr().ident())
.anyMatch(
subNode -> subNode.expr().identOrDefault().name().contains(iterVar));
ident -> ident.name().contains(iterVar) || ident.name().contains(result));
})
.collect(
Collectors.toMap(
k -> k,
v -> {
CelComprehension comprehension = v.expr().comprehension();
String iterVar = comprehension.iterVar();
long iterVarId =
// Identifiers to mangle could be the iteration variable, comprehension result
// or both, but at least one has to exist.
// As an example, [1,2].map(i, 3) would produce an optional.empty because `i`
// is not actually used.
Optional<Long> iterVarId =
CelNavigableExpr.fromExpr(comprehension.loopStep())
.allNodes()
.filter(
loopStepNode ->
loopStepNode.expr().identOrDefault().name().equals(iterVar))
.map(CelNavigableExpr::id)
.findAny()
.orElseThrow(
() -> {
throw new NoSuchElementException(
"Expected iteration variable to exist in expr id: "
+ v.id());
});

CelType iterVarType =
ast.getType(iterVarId)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for iteration variable: "
+ iterVarId));
CelType resultType =
ast.getType(comprehension.result().id())
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for result: "
+ comprehension.result().id()));
.findAny();
Optional<CelType> iterVarType =
iterVarId.map(
id ->
ast.getType(id)
.orElseThrow(
() ->
new NoSuchElementException(
"Checked type not present for iteration variable:"
+ " "
+ iterVarId)));
Optional<CelType> resultType = ast.getType(comprehension.result().id());

return MangledComprehensionType.of(iterVarType, resultType);
},
Expand Down Expand Up @@ -487,24 +507,26 @@ private CelSource mangleIdentsInMacroSource(
// Note that in the above example, the iteration variable is not replaced after normalization.
// This is because populating a macro call map upon parse generates a new unique identifier
// that does not exist in the main AST. Thus, we need to manually replace the identifier.
// Also note that this only applies when the macro is at leaf. For nested macros, the iteration
// variable actually exists in the main AST thus, this step isn't needed.
// ex: [1].map(x, [2].filter(y, x == y). Here, the variable declaration `x` exists in the AST
// but not `y`.
CelExpr.Builder macroExpr = newSource.getMacroCalls().get(originalComprehensionId).toBuilder();
// By convention, the iteration variable is always the first argument of the
// macro call expression.
CelExpr identToMangle = macroExpr.call().args().get(0);
if (!identToMangle.identOrDefault().name().equals(originalIterVar)) {
throw new IllegalStateException(
String.format(
"Expected %s for iteration variable but got %s instead.",
identToMangle.identOrDefault().name(), originalIterVar));
if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
macroExpr =
mutateExpr(
NO_OP_ID_GENERATOR,
macroExpr,
CelExpr.newBuilder()
.setIdent(
CelIdent.newBuilder()
.setName(mangledComprehensionName.iterVarName())
.build()),
identToMangle.id());
}
macroExpr =
mutateExpr(
NO_OP_ID_GENERATOR,
macroExpr,
CelExpr.newBuilder()
.setIdent(
CelIdent.newBuilder().setName(mangledComprehensionName.iterVarName()).build()),
identToMangle.id());

newSource.addMacroCalls(originalComprehensionId, macroExpr.build());
return newSource.build();
Expand Down Expand Up @@ -737,12 +759,14 @@ private static MangledComprehensionAst of(
public abstract static class MangledComprehensionType {

/** Type of iter_var */
public abstract CelType iterVarType();
public abstract Optional<CelType> iterVarType();

/** Type of comprehension result */
public abstract CelType resultType();
public abstract Optional<CelType> resultType();

private static MangledComprehensionType of(CelType iterVarType, CelType resultType) {
private static MangledComprehensionType of(
Optional<CelType> iterVarType, Optional<CelType> resultType) {
Preconditions.checkArgument(iterVarType.isPresent() || resultType.isPresent());
return new AutoValue_MutableAst_MangledComprehensionType(iterVarType, resultType);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ java_library(
"//common/navigation",
"//common/types",
"//common/types:type_providers",
"//extensions:optional_library",
"//optimizer:ast_optimizer",
"//optimizer:mutable_ast",
"//parser:operator",
Expand Down
Loading

0 comments on commit 5ebf44e

Please sign in to comment.