Skip to content

Commit

Permalink
Add mangled comprehension variables as identifier declaration to the …
Browse files Browse the repository at this point in the history
…environment

PiperOrigin-RevId: 607507168
  • Loading branch information
l46kok authored and copybara-github committed Feb 16, 2024
1 parent 496ab08 commit 629f85b
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 23 deletions.
26 changes: 24 additions & 2 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.annotations.Immutable;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelSource;
Expand Down Expand Up @@ -200,10 +201,11 @@ public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast)
* @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will
* produce @c0, @c1, @c2... as new names.
*/
public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
public MangledComprehensionAst mangleComprehensionIdentifierNames(
CelAbstractSyntaxTree ast, String newIdentPrefix) {
int iterCount;
CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast);
ImmutableSet.Builder<String> mangledComprehensionIdents = ImmutableSet.builder();
for (iterCount = 0; iterCount < iterationLimit; iterCount++) {
CelNavigableExpr comprehensionNode =
newNavigableAst
Expand All @@ -223,6 +225,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
String iterVar = comprehensionExpr.comprehension().iterVar();
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
String mangledVarName = newIdentPrefix + comprehensionNestingLevel;
mangledComprehensionIdents.add(mangledVarName);

CelExpr.Builder mutatedComprehensionExpr =
mangleIdentsInComprehensionExpr(
Expand Down Expand Up @@ -251,7 +254,7 @@ public CelAbstractSyntaxTree mangleComprehensionIdentifierNames(
throw new IllegalStateException("Max iteration count reached.");
}

return newNavigableAst.getAst();
return MangledComprehensionAst.of(newNavigableAst.getAst(), mangledComprehensionIdents.build());
}

/**
Expand Down Expand Up @@ -575,6 +578,25 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension
return nestedLevel;
}

/**
* Intermediate value class to store the mangled identifiers for iteration variable in the
* comprehension.
*/
@AutoValue
public abstract static class MangledComprehensionAst {

/** AST after the iteration variables have been mangled. */
public abstract CelAbstractSyntaxTree ast();

/** Set of identifiers with the iteration variable mangled. */
public abstract ImmutableSet<String> mangledComprehensionIdents();

private static MangledComprehensionAst of(
CelAbstractSyntaxTree ast, ImmutableSet<String> mangledComprehensionIdents) {
return new AutoValue_MutableAst_MangledComprehensionAst(ast, mangledComprehensionIdents);
}
}

/**
* Intermediate value class to store the generated CelExpr for the bind macro and the macro call
* information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import dev.cel.common.CelSource.Extension.Component;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelIdent;
Expand All @@ -45,6 +46,7 @@
import dev.cel.common.types.SimpleType;
import dev.cel.optimizer.CelAstOptimizer;
import dev.cel.optimizer.MutableAst;
import dev.cel.optimizer.MutableAst.MangledComprehensionAst;
import dev.cel.parser.Operator;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -90,10 +92,8 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
stream(Operator.values()).map(Operator::getFunction),
stream(Standard.Function.values()).map(Standard.Function::getFunction))
.collect(toImmutableSet());

private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);

private final SubexpressionOptimizerOptions cseOptions;
private final MutableAst mutableAst;

Expand Down Expand Up @@ -125,10 +125,13 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
// Retain the original expected result type, so that it can be reset in celBuilder at the end of
// the optimization pass.
CelType resultType = navigableAst.getAst().getResultType();
CelAbstractSyntaxTree astToModify =
MangledComprehensionAst mangledComprehensionAst =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
CelAbstractSyntaxTree astToModify = mangledComprehensionAst.ast();
CelSource sourceToModify = astToModify.getSource();
ImmutableSet<CelVarDecl> mangledIdentDecls =
newMangledIdentDecls(celBuilder, mangledComprehensionAst);

int blockIdentifierIndex = 0;
int iterCount;
Expand Down Expand Up @@ -187,6 +190,9 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(
return astToModify;
}

// Add all mangled comprehension identifiers to the environment, so that the subexpressions can
// retain context to them.
celBuilder.addVarDeclarations(mangledIdentDecls);
// Type-check all sub-expressions then add them as block identifiers to the CEL environment
addBlockIdentsToEnv(celBuilder, subexpressions);

Expand Down Expand Up @@ -254,10 +260,47 @@ private static void addBlockIdentsToEnv(CelBuilder celBuilder, List<CelExpr> sub
}
}

private static ImmutableSet<CelVarDecl> newMangledIdentDecls(
CelBuilder celBuilder, MangledComprehensionAst mangledComprehensionAst) {
if (mangledComprehensionAst.mangledComprehensionIdents().isEmpty()) {
return ImmutableSet.of();
}
CelAbstractSyntaxTree ast = mangledComprehensionAst.ast();
try {
ast = celBuilder.build().check(ast).getAst();
} catch (CelValidationException e) {
throw new IllegalStateException("Failed to type-check mangled AST.", e);
}

ImmutableSet.Builder<CelVarDecl> mangledVarDecls = ImmutableSet.builder();
for (String ident : mangledComprehensionAst.mangledComprehensionIdents()) {
CelExpr mangledIdentExpr =
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes()
.filter(node -> node.getKind().equals(Kind.IDENT))
.map(CelNavigableExpr::expr)
.filter(expr -> expr.ident().name().equals(ident))
.findAny()
.orElse(null);
if (mangledIdentExpr == null) {
break;
}

CelType mangledIdentType =
ast.getType(mangledIdentExpr.id()).orElseThrow(() -> new NoSuchElementException("?"));
mangledVarDecls.add(CelVarDecl.newVarDeclaration(ident, mangledIdentType));
}

return mangledVarDecls.build();
}

private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) {
CelAbstractSyntaxTree astToModify =
mutableAst.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX);
mutableAst
.mangleComprehensionIdentifierNames(
navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX)
.ast();
CelSource sourceToModify = astToModify.getSource();

int bindIdentifierIndex = 0;
Expand Down Expand Up @@ -526,9 +569,9 @@ public abstract static class Builder {

/**
* Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed
* to produce a more compact AST. {@link com.google.api.expr.SourceInfo.Extension} field will
* be populated in the AST to inform that special runtime support is required to evaluate the
* optimized expression.
* to produce a more compact AST. {@link CelSource.Extension} field will be populated in the
* AST to inform that special runtime support is required to evaluate the optimized
* expression.
*/
public abstract Builder enableCelBlock(boolean value);

Expand Down
9 changes: 6 additions & 3 deletions optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ public void comprehension_replaceLoopStep() throws Exception {
public void mangleComprehensionVariable_singleMacro() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(mangledAst.getExpr().toString())
.isEqualTo(
Expand Down Expand Up @@ -741,7 +742,8 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(mangledAst.getExpr().toString())
.isEqualTo(
Expand Down Expand Up @@ -858,7 +860,8 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
public void mangleComprehensionVariable_hasMacro_noOp() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst();

CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c");
CelAbstractSyntaxTree mangledAst =
MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c").ast();

assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)");
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ public class SubexpressionOptimizerTest {
CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL)))
// Similarly, this is a test only decl (index0 -> @index0)
.addVarDeclarations(
CelVarDecl.newVarDeclaration("c0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("c1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index0", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index1", SimpleType.DYN),
CelVarDecl.newVarDeclaration("index2", SimpleType.DYN),
Expand Down Expand Up @@ -506,8 +508,9 @@ private enum CseTestCase {
"size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2",
"size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))"
+ ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2",
"Currently Unsupported"), // TODO: Handle comprehension variables that fall
// outside the cel.block scope
"cel.@block([@c1 + @c1, @c0 + @c0], "
+ "size([\"foo\", \"bar\"].map(@c1, [@index0, @index0])"
+ ".map(@c0, [@index1, @index1])) == 2)"),
PRESENCE_TEST(
"has({'a': true}.a) && {'a':true}['a']",
"cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])",
Expand Down Expand Up @@ -683,10 +686,6 @@ public void cse_withCelBind_macroMapUnpopulated(@TestParameter CseTestCase testC
@Test
public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCase)
throws Exception {
if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) {
// TODO: Handle comprehension variables that fall outside the cel.block scope
return;
}
CelOptimizer celOptimizer =
newCseOptimizer(
SubexpressionOptimizerOptions.newBuilder()
Expand All @@ -709,10 +708,6 @@ public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCa
@Test
public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase testCase)
throws Exception {
if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) {
// TODO: Handle comprehension variables that fall outside the cel.block scope
return;
}
CelOptimizer celOptimizer =
newCseOptimizer(
SubexpressionOptimizerOptions.newBuilder()
Expand All @@ -732,6 +727,32 @@ public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase test
.isEqualTo(true);
}

@Test
public void celBlock_nestedComprehension_iterVarReferencedAcrossComprehensions()
throws Exception {
String nestedComprehension =
"[\"foo\"].map(x, [[\"bar\"], [x + x, x + x]] + [\"bar\"].map(y, [x + y, [\"baz\"].map(z,"
+ " [x + y + z, x + y, x + y + z])])) == [[[\"bar\"], [\"foofoo\", \"foofoo\"],"
+ " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]]";
CelOptimizer celOptimizer =
newCseOptimizer(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.build());
CelAbstractSyntaxTree ast = CEL.compile(nestedComprehension).getAst();

CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast);

assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true);
assertThat(CEL_UNPARSER.unparse(optimizedAst))
.isEqualTo(
"cel.@block([@c0 + @c0, [\"bar\"], @c0 + @c1, @index2 + @c2], [\"foo\"].map(@c0,"
+ " [@index1, [@index0, @index0]] + @index1.map(@c1, [@index2, [\"baz\"].map(@c2,"
+ " [@index3, @index2, @index3])])) == [[@index1, [\"foofoo\", \"foofoo\"],"
+ " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]])");
}

@Test
public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception {
Cel cel = newCelBuilder().setResultType(SimpleType.BOOL).build();
Expand Down Expand Up @@ -1264,6 +1285,37 @@ public void lazyEval_multipleBlockIndices_cascaded() throws Exception {
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@SuppressWarnings("Immutable") // Test only
public void lazyEval_nestedComprehension_indexReferencedInNestedScopes() throws Exception {
AtomicInteger invocation = new AtomicInteger();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addFunctionBindings(
CelFunctionBinding.from(
"get_true_overload",
ImmutableList.of(),
arg -> {
invocation.getAndIncrement();
return true;
}))
.build();
// Equivalent of [true, false, true].map(c0, [c0].map(c1, [c0, c1, true]))
CelAbstractSyntaxTree ast =
compileUsingInternalFunctions(
"cel.block([c0, c1, get_true()], [index2, false, index2].map(c0, [c0].map(c1, [index0,"
+ " index1, index2]))) == [[[true, true, true]], [[false, false, true]], [[true,"
+ " true, true]]]");

boolean result = (boolean) celRuntime.createProgram(ast).eval();

assertThat(result).isTrue();
// Even though the function get_true() is referenced across different comprehension scopes,
// it still gets memoized only once.
assertThat(invocation.get()).isEqualTo(1);
}

@Test
@TestParameters("{source: 'cel.block([])'}")
@TestParameters("{source: 'cel.block([1])'}")
Expand Down

0 comments on commit 629f85b

Please sign in to comment.