Skip to content

Commit

Permalink
Perform CSE on presence tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599886305
  • Loading branch information
l46kok authored and copybara-github committed Jan 19, 2024
1 parent 5a7cbab commit cd11baf
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 12 deletions.
4 changes: 2 additions & 2 deletions common/src/main/java/dev/cel/common/ast/CelExprFormatter.java
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ private void appendSelect(CelExpr.CelSelect celSelect) {
indent();
formatExpr(celSelect.operand());
outdent();
append(".");
append(celSelect.field());
appendWithoutIndent(".");
appendWithoutIndent(celSelect.field());
if (celSelect.testOnly()) {
appendWithoutIndent("~presence_test");
}
Expand Down
33 changes: 33 additions & 0 deletions common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,37 @@ public void comprehension() throws Exception {
+ " }\n"
+ "}");
}

@Test
public void ternaryWithPresenceTest() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()))
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.build();
CelAbstractSyntaxTree ast =
celCompiler.compile("has(msg.single_any) ? msg.single_any : 10").getAst();

String formattedExpr = CelExprFormatter.format(ast.getExpr());

assertThat(formattedExpr)
.isEqualTo(
"CALL [5] {\n"
+ " function: _?_:_\n"
+ " args: {\n"
+ " SELECT [4] {\n"
+ " IDENT [2] {\n"
+ " name: msg\n"
+ " }.single_any~presence_test\n"
+ " }\n"
+ " SELECT [7] {\n"
+ " IDENT [6] {\n"
+ " name: msg\n"
+ " }.single_any\n"
+ " }\n"
+ " CONSTANT [8] { value: 10 }\n"
+ " }\n"
+ "}");
}
}
18 changes: 18 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,24 @@ public interface CelAstOptimizer {
CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)
throws CelOptimizationException;

/**
* Replaces a subtree in the given expression node. This operation is intended for AST
* optimization purposes.
*
* <p>This is a very dangerous operation. Callers should re-typecheck the mutated AST and
* additionally verify that the resulting AST is semantically valid.
*
* <p>All expression IDs will be renumbered in a stable manner to ensure there's no ID collision
* between the nodes. The renumbering occurs even if the subtree was not replaced.
*
* @param celExpr Original expression node to rewrite.
* @param newExpr New CelExpr to replace the subtree with.
* @param exprIdToReplace Expression id of the subtree that is getting replaced.
*/
default CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) {
return MutableAst.replaceSubtree(celExpr, newExpr, exprIdToReplace);
}

/**
* Replaces a subtree in the given AST. This operation is intended for AST optimization purposes.
*
Expand Down
9 changes: 9 additions & 0 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ static CelExpr clearExprIds(CelExpr celExpr) {
return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build();
}

/** Mutates the given {@link CelExpr} by replacing a subtree at the given index. */
static CelExpr replaceSubtree(CelExpr expr, CelExpr newExpr, long exprIdToReplace) {
return replaceSubtree(
CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()),
CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()),
exprIdToReplace)
.getExpr();
}

/**
* Mutates the given AST by replacing a subtree at a given index.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
int bindIdentifierIndex = 0;
int iterCount;
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
CelNavigableExpr cseCandidate = findCseCandidate(astToModify).orElse(null);
CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null);
if (cseCandidate == null) {
break;
}
Expand All @@ -107,7 +107,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {

// Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time.
ImmutableList<CelExpr> allCseCandidates =
getAllCseCandidatesStream(astToModify, cseCandidate.expr()).collect(toImmutableList());
getAllCseCandidatesStream(astToModify, cseCandidate).collect(toImmutableList());

// Replace all CSE candidates with new bind identifier
for (CelExpr semanticallyEqualNode : allCseCandidates) {
Expand Down Expand Up @@ -142,7 +142,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
// Insert the new bind call
astToModify =
replaceSubtreeWithNewBindMacro(
astToModify, bindIdentifier, cseCandidate.expr(), lca.expr(), lca.id());
astToModify, bindIdentifier, cseCandidate, lca.expr(), lca.id());

// Retain the existing macro calls in case if the bind identifiers are replacing a subtree
// that contains a comprehension.
Expand Down Expand Up @@ -224,8 +224,8 @@ private Optional<CelNavigableExpr> findCseCandidate(CelAbstractSyntaxTree ast) {
.collect(toImmutableList());

for (CelNavigableExpr node : allNodes) {
// Strip out all IDs to test equivalence
CelExpr celExpr = clearExprIds(node.expr());
// Normalize the expr to test semantic equivalence.
CelExpr celExpr = normalizeForEquality(node.expr());
if (encounteredNodes.contains(celExpr)) {
return Optional.of(node);
}
Expand All @@ -240,6 +240,7 @@ private static boolean canEliminate(CelNavigableExpr navigableExpr) {
return !navigableExpr.getKind().equals(Kind.CONSTANT)
&& !navigableExpr.getKind().equals(Kind.IDENT)
&& !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX)
&& !navigableExpr.expr().selectOrDefault().testOnly()
&& isAllowedFunction(navigableExpr)
&& isWithinInlineableComprehension(navigableExpr);
}
Expand Down Expand Up @@ -271,7 +272,7 @@ private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) {
}

private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) {
return clearExprIds(expr1).equals(clearExprIds(expr2));
return normalizeForEquality(expr1).equals(normalizeForEquality(expr2));
}

private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) {
Expand All @@ -282,6 +283,47 @@ private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) {
return true;
}

/**
* Converts the {@link CelExpr} to make it suitable for performing semantically equals check in
* {@link #areSemanticallyEqual(CelExpr, CelExpr)}.
*
* <p>Specifically, this will:
*
* <ul>
* <li>Set all expr IDs in the expression tree to 0.
* <li>Strip all presence tests (i.e: testOnly is marked as false on {@link
* CelExpr.ExprKind.Kind#SELECT}
* </ul>
*/
private CelExpr normalizeForEquality(CelExpr celExpr) {
int iterCount;
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
CelExpr presenceTestExpr =
CelNavigableExpr.fromExpr(celExpr)
.allNodes()
.map(CelNavigableExpr::expr)
.filter(expr -> expr.selectOrDefault().testOnly())
.findAny()
.orElse(null);
if (presenceTestExpr == null) {
break;
}

CelExpr newExpr =
presenceTestExpr.toBuilder()
.setSelect(presenceTestExpr.select().toBuilder().setTestOnly(false).build())
.build();

celExpr = replaceSubtree(celExpr, newExpr, newExpr.id());
}

if (iterCount >= cseOptions.maxIterationLimit()) {
throw new IllegalStateException("Max iteration count reached.");
}

return clearExprIds(celExpr);
}

/** Options to configure how Common Subexpression Elimination behave. */
@AutoValue
public abstract static class SubexpressionOptimizerOptions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import dev.cel.common.CelOptions;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.extensions.CelExtensions;
Expand Down Expand Up @@ -75,7 +76,8 @@ public class SubexpressionOptimizerTest {
.setSingleInt64(10L)
.putMapInt32Int64(0, 1)
.putMapInt32Int64(1, 5)
.putMapInt32Int64(2, 2)))
.putMapInt32Int64(2, 2)
.putMapStringString("key", "A")))
.build();

private static CelBuilder newCelBuilder() {
Expand All @@ -92,6 +94,7 @@ private static CelBuilder newCelBuilder() {
"custom_func",
newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addVar("x", SimpleType.DYN)
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
}

Expand Down Expand Up @@ -314,6 +317,13 @@ private enum CseTestCase {
"[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]",
"cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == "
+ "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"),
INCLUSION_LIST(
"1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]",
"cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&"
+ " @r1))"),
INCLUSION_MAP(
"2 in {'a': 1, 2: {true: false}, 3: {true: false}}",
"2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})"),
MACRO_SHADOWED_VARIABLE(
"[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3",
"cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0, @c0 - 1 > 3) ||"
Expand All @@ -322,6 +332,86 @@ private enum CseTestCase {
"size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2",
"size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))"
+ ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"),
PRESENCE_TEST(
"has({'a': true}.a) && {'a':true}['a']",
"cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])"),
PRESENCE_TEST_WITH_TERNARY(
"(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10",
"cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10"),
PRESENCE_TEST_WITH_TERNARY_2(
"(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :"
+ " msg.oneof_type.payload.single_int64 * 0) == 10",
"cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?"
+ " @r1 : (@r1 * 0))) == 10"),
PRESENCE_TEST_WITH_TERNARY_3(
"(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :"
+ " msg.oneof_type.payload.single_int64 * 0) == 10",
"cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64,"
+ " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10"),
/**
* Input:
*
* <pre>{@code
* (
* has(msg.oneof_type) &&
* has(msg.oneof_type.payload) &&
* has(msg.oneof_type.payload.single_int64)
* ) ?
* (
* (
* has(msg.oneof_type.payload.map_string_string) &&
* has(msg.oneof_type.payload.map_string_string.key)
* ) ?
* msg.oneof_type.payload.map_string_string.key == "A"
* : false
* )
* : false
* }</pre>
*
* Unparsed:
*
* <pre>{@code
* cel.bind(
* @r0, msg.oneof_type,
* cel.bind(
* @r1, @r0.payload,
* has(msg.oneof_type) && has(@r0.payload) && has(@r1.single_int64) ?
* cel.bind(
* @r2, @r1.map_string_string,
* has(@r1.map_string_string) && has(@r2.key) ? @r2.key == "A" : false,
* )
* : false,
* ),
* )
* }</pre>
*/
PRESENCE_TEST_WITH_TERNARY_NESTED(
"(has(msg.oneof_type) && has(msg.oneof_type.payload) &&"
+ " has(msg.oneof_type.payload.single_int64)) ?"
+ " ((has(msg.oneof_type.payload.map_string_string) &&"
+ " has(msg.oneof_type.payload.map_string_string.key)) ?"
+ " msg.oneof_type.payload.map_string_string.key == 'A' : false) : false",
"cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload, (has(msg.oneof_type) &&"
+ " has(@r0.payload) && has(@r1.single_int64)) ? cel.bind(@r2, @r1.map_string_string,"
+ " (has(@r1.map_string_string) && has(@r2.key)) ? (@r2.key == \"A\") : false) :"
+ " false))"),
OPTIONAL_LIST(
"[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10,"
+ " [5], [5]]",
"cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) =="
+ " cel.bind(@r1, [5], [10, @r1, @r1])"),
OPTIONAL_MAP(
"{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] =="
+ " 'hellohello'",
"cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) =="
+ " \"hellohello\""),
OPTIONAL_MESSAGE(
"TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:"
+ " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:"
+ " optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}.single_int64 == 5",
"cel.bind(@r0, TestAllTypes{"
+ "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, "
+ "@r0.single_int32 + @r0.single_int64) == 5"),
;

private final String source;
Expand All @@ -342,7 +432,9 @@ public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCas

assertThat(
CEL.createProgram(optimizedAst)
.eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
.eval(
ImmutableMap.of(
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
.isEqualTo(true);
assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed);
}
Expand All @@ -366,7 +458,9 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
assertThat(
celWithoutMacroMap
.createProgram(optimizedAst)
.eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
.eval(
ImmutableMap.of(
"msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
.isEqualTo(true);
}

Expand All @@ -384,7 +478,8 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
@TestParameters("{source: 'custom_func(1) + custom_func(1)'}")
// Duplicated but nested calls.
@TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}")
// Ternary with presence test is not supported yet.
// This cannot be optimized. Extracting the common subexpression would presence test
// the bound identifier (e.g: has(@r0)), which is not valid.
@TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}")
public void cse_noop(String source) throws Exception {
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();
Expand Down

0 comments on commit cd11baf

Please sign in to comment.