diff --git a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java index 8ab71c094..32e0233d0 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java @@ -122,8 +122,8 @@ private void appendSelect(CelExpr.CelSelect celSelect) { indent(); formatExpr(celSelect.operand()); outdent(); - append("."); - append(celSelect.field()); + appendWithoutIndent("."); + appendWithoutIndent(celSelect.field()); if (celSelect.testOnly()) { appendWithoutIndent("~presence_test"); } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java index 2c863b1f4..0efb890c1 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java @@ -336,4 +336,37 @@ public void comprehension() throws Exception { + " }\n" + "}"); } + + @Test + public void ternaryWithPresenceTest() throws Exception { + CelCompiler celCompiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .build(); + CelAbstractSyntaxTree ast = + celCompiler.compile("has(msg.single_any) ? 
msg.single_any : 10").getAst(); + + String formattedExpr = CelExprFormatter.format(ast.getExpr()); + + assertThat(formattedExpr) + .isEqualTo( + "CALL [5] {\n" + + " function: _?_:_\n" + + " args: {\n" + + " SELECT [4] {\n" + + " IDENT [2] {\n" + + " name: msg\n" + + " }.single_any~presence_test\n" + + " }\n" + + " SELECT [7] {\n" + + " IDENT [6] {\n" + + " name: msg\n" + + " }.single_any\n" + + " }\n" + + " CONSTANT [8] { value: 10 }\n" + + " }\n" + + "}"); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index c7f166460..feac4ce91 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -26,6 +26,24 @@ public interface CelAstOptimizer { CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) throws CelOptimizationException; + /** + * Replaces a subtree in the given expression node. This operation is intended for AST + * optimization purposes. + * + *
+ * <p>This is a very dangerous operation. Callers should re-typecheck the mutated AST and
+ * additionally verify that the resulting AST is semantically valid.
+ *
+ * <p>All expression IDs will be renumbered in a stable manner to ensure there's no ID collision
+ * between the nodes. The renumbering occurs even if the subtree was not replaced.
+ *
+ * @param celExpr Original expression node to rewrite.
+ * @param newExpr New CelExpr to replace the subtree with.
+ * @param exprIdToReplace Expression id of the subtree that is getting replaced.
+ */
+  default CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) {
+    // The rewrite itself is performed by the package-private MutableAst helper.
+    CelExpr rewrittenExpr = MutableAst.replaceSubtree(celExpr, newExpr, exprIdToReplace);
+    return rewrittenExpr;
+  }
+
/**
* Replaces a subtree in the given AST. This operation is intended for AST optimization purposes.
*
diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
index 45cbb0ff6..1c3427169 100644
--- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
+++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
@@ -66,6 +66,15 @@ static CelExpr clearExprIds(CelExpr celExpr) {
return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build();
}
+  /**
+   * Replaces the subtree of {@code expr} whose expression ID is {@code exprIdToReplace} with
+   * {@code newExpr}. Both expressions are wrapped in parsed ASTs that carry an empty source, the
+   * rewrite is delegated to the AST-based overload, and the resulting AST's expression is returned.
+   */
+  static CelExpr replaceSubtree(CelExpr expr, CelExpr newExpr, long exprIdToReplace) {
+    CelAbstractSyntaxTree originalAst =
+        CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build());
+    CelAbstractSyntaxTree replacementAst =
+        CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build());
+    return replaceSubtree(originalAst, replacementAst, exprIdToReplace).getExpr();
+  }
+
/**
* Mutates the given AST by replacing a subtree at a given index.
*
diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
index d2f1ae7e6..d7eda5b44 100644
--- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
+++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java
@@ -97,7 +97,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
int bindIdentifierIndex = 0;
int iterCount;
for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) {
- CelNavigableExpr cseCandidate = findCseCandidate(astToModify).orElse(null);
+ CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null);
if (cseCandidate == null) {
break;
}
@@ -107,7 +107,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) {
// Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time.
ImmutableList Specifically, this will:
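+ *
+ * <ul>
+ * <li>Strip all presence tests, i.e. mark testOnly as false on every select expression.
+ * <li>Clear all expression IDs in the tree (see {@code clearExprIds}).
+ * </ul>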
+ */
+  private CelExpr normalizeForEquality(CelExpr celExpr) {
+    int iterationLimit = cseOptions.maxIterationLimit();
+    CelExpr normalizedExpr = celExpr;
+    int rewriteCount = 0;
+    while (rewriteCount < iterationLimit) {
+      // Look for any select node that is still flagged as a presence test.
+      CelExpr presenceTest =
+          CelNavigableExpr.fromExpr(normalizedExpr)
+              .allNodes()
+              .map(CelNavigableExpr::expr)
+              .filter(node -> node.selectOrDefault().testOnly())
+              .findAny()
+              .orElse(null);
+      if (presenceTest == null) {
+        // No presence tests remain, so only the expression IDs are left to clear.
+        return clearExprIds(normalizedExpr);
+      }
+
+      // Same select node with testOnly switched off; it keeps the ID of the node it replaces.
+      CelExpr.CelSelect plainSelect = presenceTest.select().toBuilder().setTestOnly(false).build();
+      CelExpr plainSelectExpr = presenceTest.toBuilder().setSelect(plainSelect).build();
+      normalizedExpr = replaceSubtree(normalizedExpr, plainSelectExpr, plainSelectExpr.id());
+      rewriteCount++;
+    }
+
+    // NOTE(review): this is also reached when the last permitted pass stripped the final
+    // presence test (no further scan is made) -- confirm that is the intended limit semantics.
+    throw new IllegalStateException("Max iteration count reached.");
+  }
+
/** Options to configure how Common Subexpression Elimination behave. */
@AutoValue
public abstract static class SubexpressionOptimizerOptions {
diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
index 79e04fc64..98e1ff75a 100644
--- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
+++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java
@@ -31,6 +31,7 @@
import dev.cel.common.CelOptions;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
+import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.extensions.CelExtensions;
@@ -75,7 +76,8 @@ public class SubexpressionOptimizerTest {
.setSingleInt64(10L)
.putMapInt32Int64(0, 1)
.putMapInt32Int64(1, 5)
- .putMapInt32Int64(2, 2)))
+ .putMapInt32Int64(2, 2)
+ .putMapStringString("key", "A")))
.build();
private static CelBuilder newCelBuilder() {
@@ -92,6 +94,7 @@ private static CelBuilder newCelBuilder() {
"custom_func",
newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addVar("x", SimpleType.DYN)
+ .addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
}
@@ -314,6 +317,13 @@ private enum CseTestCase {
"[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]",
"cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == "
+ "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"),
+ INCLUSION_LIST(
+ "1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]",
+ "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&"
+ + " @r1))"),
+ INCLUSION_MAP(
+ "2 in {'a': 1, 2: {true: false}, 3: {true: false}}",
+ "2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})"),
MACRO_SHADOWED_VARIABLE(
"[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3",
"cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0, @c0 - 1 > 3) ||"
@@ -322,6 +332,86 @@ private enum CseTestCase {
"size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2",
"size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))"
+ ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"),
+ PRESENCE_TEST(
+ "has({'a': true}.a) && {'a':true}['a']",
+ "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])"),
+ PRESENCE_TEST_WITH_TERNARY(
+ "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10",
+ "cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10"),
+ PRESENCE_TEST_WITH_TERNARY_2(
+ "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :"
+ + " msg.oneof_type.payload.single_int64 * 0) == 10",
+ "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?"
+ + " @r1 : (@r1 * 0))) == 10"),
+ PRESENCE_TEST_WITH_TERNARY_3(
+ "(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :"
+ + " msg.oneof_type.payload.single_int64 * 0) == 10",
+ "cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64,"
+ + " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10"),
+ /**
+ * Input:
+ *
+ * <pre>{@code
+ * (
+ *   has(msg.oneof_type) &&
+ *   has(msg.oneof_type.payload) &&
+ *   has(msg.oneof_type.payload.single_int64)
+ * ) ?
+ *   (
+ *     (
+ *       has(msg.oneof_type.payload.map_string_string) &&
+ *       has(msg.oneof_type.payload.map_string_string.key)
+ *     ) ?
+ *       msg.oneof_type.payload.map_string_string.key == "A"
+ *     : false
+ *   )
+ * : false
+ * }</pre>
+ *
+ * Unparsed:
+ *
+ * <pre>{@code
+ * cel.bind(
+ *   @r0, msg.oneof_type,
+ *   cel.bind(
+ *     @r1, @r0.payload,
+ *     has(msg.oneof_type) && has(@r0.payload) && has(@r1.single_int64) ?
+ *       cel.bind(
+ *         @r2, @r1.map_string_string,
+ *         has(@r1.map_string_string) && has(@r2.key) ? @r2.key == "A" : false,
+ *       )
+ *     : false,
+ *   ),
+ * )
+ * }</pre>
+ */
+ PRESENCE_TEST_WITH_TERNARY_NESTED(
+ "(has(msg.oneof_type) && has(msg.oneof_type.payload) &&"
+ + " has(msg.oneof_type.payload.single_int64)) ?"
+ + " ((has(msg.oneof_type.payload.map_string_string) &&"
+ + " has(msg.oneof_type.payload.map_string_string.key)) ?"
+ + " msg.oneof_type.payload.map_string_string.key == 'A' : false) : false",
+ "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload, (has(msg.oneof_type) &&"
+ + " has(@r0.payload) && has(@r1.single_int64)) ? cel.bind(@r2, @r1.map_string_string,"
+ + " (has(@r1.map_string_string) && has(@r2.key)) ? (@r2.key == \"A\") : false) :"
+ + " false))"),
+ OPTIONAL_LIST(
+ "[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10,"
+ + " [5], [5]]",
+ "cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) =="
+ + " cel.bind(@r1, [5], [10, @r1, @r1])"),
+ OPTIONAL_MAP(
+ "{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] =="
+ + " 'hellohello'",
+ "cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) =="
+ + " \"hellohello\""),
+ OPTIONAL_MESSAGE(
+ "TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:"
+ + " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:"
+ + " optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}.single_int64 == 5",
+ "cel.bind(@r0, TestAllTypes{"
+ + "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, "
+ + "@r0.single_int32 + @r0.single_int64) == 5"),
;
private final String source;
@@ -342,7 +432,9 @@ public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCas
assertThat(
CEL.createProgram(optimizedAst)
- .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
+ .eval(
+ ImmutableMap.of(
+ "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
.isEqualTo(true);
assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed);
}
@@ -366,7 +458,9 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
assertThat(
celWithoutMacroMap
.createProgram(optimizedAst)
- .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L)))
+ .eval(
+ ImmutableMap.of(
+ "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L))))
.isEqualTo(true);
}
@@ -384,7 +478,8 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr
@TestParameters("{source: 'custom_func(1) + custom_func(1)'}")
// Duplicated but nested calls.
@TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}")
- // Ternary with presence test is not supported yet.
+ // This cannot be optimized. Extracting the common subexpression would apply the presence test
+ // to the bound identifier (e.g., has(@r0)), which is not valid.
@TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}")
public void cse_noop(String source) throws Exception {
CelAbstractSyntaxTree ast = CEL.compile(source).getAst();