From 70ef6f94b713aa100bf1fa1f7cabca32830c2743 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 15 Feb 2024 16:08:03 -0800 Subject: [PATCH] Augment CSE to produce optimized ASTs using cel.block PiperOrigin-RevId: 607486802 --- .../dev/cel/optimizer/CelAstOptimizer.java | 4 +- .../dev/cel/optimizer/CelOptimizerImpl.java | 8 +- .../optimizers/ConstantFoldingOptimizer.java | 4 +- .../optimizers/SubexpressionOptimizer.java | 160 ++++- .../SubexpressionOptimizerTest.java | 546 +++++++++++++++--- 5 files changed, 642 insertions(+), 80 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index b9f35dcf..730b5cc3 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -14,7 +14,7 @@ package dev.cel.optimizer; -import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.navigation.CelNavigableAst; @@ -22,6 +22,6 @@ public interface CelAstOptimizer { /** Optimizes a single AST. 
*/ - CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder cel) throws CelOptimizationException; } diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java index 8ce74e1a..d57181fc 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelValidationException; import dev.cel.common.navigation.CelNavigableAst; @@ -39,11 +40,12 @@ public CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptim } CelAbstractSyntaxTree optimizedAst = ast; + CelBuilder celBuilder = cel.toCelBuilder(); try { for (CelAstOptimizer optimizer : astOptimizers) { - CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); - optimizedAst = optimizer.optimize(navigableAst, cel); - optimizedAst = cel.check(optimizedAst).getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(optimizedAst); + optimizedAst = optimizer.optimize(navigableAst, celBuilder); + optimizedAst = celBuilder.build().check(optimizedAst).getAst(); } } catch (CelValidationException e) { throw new CelOptimizationException( diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index 2e63c036..8ca19e94 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -19,6 +19,7 @@ import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import dev.cel.bundle.Cel; +import 
dev.cel.bundle.CelBuilder; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelValidationException; import dev.cel.common.ast.CelConstant; @@ -76,8 +77,9 @@ public static ConstantFoldingOptimizer newInstance( private final MutableAst mutableAst; @Override - public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) + public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder) throws CelOptimizationException { + Cel cel = celBuilder.build(); Set visitedExprs = new HashSet<>(); int iterCount = 0; while (true) { diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 241680bb..150a91dc 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -23,13 +23,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; -import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; import dev.cel.checker.Standard; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelSource; +import dev.cel.common.CelValidationException; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.navigation.CelNavigableAst; @@ -41,8 +43,10 @@ import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.MutableAst; import dev.cel.parser.Operator; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.NoSuchElementException; import java.util.Optional; import java.util.stream.Stream; @@ -63,6 +67,12 @@ * 
cel.bind(@r0, message.child.text_map[x], * @r0.startsWith("hello") && @r0.endsWith("world")) * } + * + * Or, using the equivalent form of cel.@block (requires special runtime support): + * {@code + * cel.@block([message.child.text_map[x]], + * @index0.startsWith("hello") && @index0.endsWith("world")) + * } * */ public class SubexpressionOptimizer implements CelAstOptimizer { @@ -71,6 +81,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer { private static final String BIND_IDENTIFIER_PREFIX = "@r"; private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c"; private static final String CEL_BLOCK_FUNCTION = "cel.@block"; + private static final String BLOCK_INDEX_PREFIX = "@index"; private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = Streams.concat( stream(Operator.values()).map(Operator::getFunction), @@ -96,7 +107,138 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c } @Override - public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { + public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, CelBuilder celBuilder) { + return cseOptions.enableCelBlock() + ? optimizeUsingCelBlock(navigableAst, celBuilder) + : optimizeUsingCelBind(navigableAst); + } + + private CelAbstractSyntaxTree optimizeUsingCelBlock( + CelNavigableAst navigableAst, CelBuilder celBuilder) { + // Retain the original expected result type, so that it can be reset in celBuilder at the end of + // the optimization pass. 
+ CelType resultType = navigableAst.getAst().getResultType(); + CelAbstractSyntaxTree astToModify = + mutableAst.mangleComprehensionIdentifierNames( + navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); + CelSource sourceToModify = astToModify.getSource(); + + int blockIdentifierIndex = 0; + int iterCount; + ArrayList subexpressions = new ArrayList<>(); + for (iterCount = 0; iterCount < cseOptions.iterationLimit(); iterCount++) { + CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null); + if (cseCandidate == null) { + break; + } + subexpressions.add(cseCandidate); + + String blockIdentifier = BLOCK_INDEX_PREFIX + blockIdentifierIndex++; + + // Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time. + ImmutableList allCseCandidates = + getAllCseCandidatesStream(astToModify, cseCandidate).collect(toImmutableList()); + + // Replace all CSE candidates with new block index identifier + for (CelExpr semanticallyEqualNode : allCseCandidates) { + iterCount++; + // Refetch the candidate expr as mutating the AST could have renumbered its IDs. + CelExpr exprToReplace = + getAllCseCandidatesStream(astToModify, semanticallyEqualNode) + .findAny() + .orElseThrow( + () -> + new NoSuchElementException( + "No value present for expr ID: " + semanticallyEqualNode.id())); + + astToModify = + mutableAst.replaceSubtree( + astToModify, + CelExpr.newBuilder() + .setIdent(CelIdent.newBuilder().setName(blockIdentifier).build()) + .build(), + exprToReplace.id()); + } + + sourceToModify = + sourceToModify.toBuilder() + .addAllMacroCalls(astToModify.getSource().getMacroCalls()) + .build(); + astToModify = CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), sourceToModify); + + // Retain the existing macro calls in case if the block identifiers are replacing a subtree + // that contains a comprehension. 
+ sourceToModify = astToModify.getSource(); + } + + if (iterCount >= cseOptions.iterationLimit()) { + throw new IllegalStateException("Max iteration count reached."); + } + + if (iterCount == 0) { + // No modification has been made. + return astToModify; + } + + // Type-check all sub-expressions then add them as block identifiers to the CEL environment + addBlockIdentsToEnv(celBuilder, subexpressions); + + // Wrap the optimized expression in cel.block + celBuilder.addFunctionDeclarations(newCelBlockFunctionDecl(resultType)); + int newId = 0; + CelExpr blockExpr = + CelExpr.newBuilder() + .setId(++newId) + .setCall( + CelCall.newBuilder() + .setFunction(CEL_BLOCK_FUNCTION) + .addArgs( + CelExpr.ofCreateListExpr( + ++newId, ImmutableList.copyOf(subexpressions), ImmutableList.of()), + astToModify.getExpr()) + .build()) + .build(); + astToModify = + mutableAst.renumberIdsConsecutively( + CelAbstractSyntaxTree.newParsedAst(blockExpr, astToModify.getSource())); + + if (!cseOptions.populateMacroCalls()) { + astToModify = + CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); + } + + // Restore the expected result type the environment had prior to optimization. + celBuilder.setResultType(resultType); + return astToModify; + } + + /** + * Adds all subexpressions as numbered identifiers that act as indexers into cel.@block + * (ex: @index0, @index1..). Each subexpression is type-checked, then its result type is used as + * the new identifier's type. + */ + private static void addBlockIdentsToEnv(CelBuilder celBuilder, List subexpressions) { + // The resulting type of the subexpressions will likely be different from the + // entire expression's expected result type. 
+ celBuilder.setResultType(SimpleType.DYN); + + for (int i = 0; i < subexpressions.size(); i++) { + CelExpr subexpression = subexpressions.get(i); + + CelAbstractSyntaxTree subAst = + CelAbstractSyntaxTree.newParsedAst(subexpression, CelSource.newBuilder().build()); + + try { + subAst = celBuilder.build().check(subAst).getAst(); + } catch (CelValidationException e) { + throw new IllegalStateException("Failed to type-check subexpression", e); + } + + celBuilder.addVar("@index" + i, subAst.getResultType()); + } + } + + private CelAbstractSyntaxTree optimizeUsingCelBind(CelNavigableAst navigableAst) { CelAbstractSyntaxTree astToModify = mutableAst.mangleComprehensionIdentifierNames( navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); @@ -166,12 +308,13 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { return astToModify; } + astToModify = mutableAst.renumberIdsConsecutively(astToModify); if (!cseOptions.populateMacroCalls()) { astToModify = CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); } - return mutableAst.renumberIdsConsecutively(astToModify); + return astToModify; } private Stream getAllCseCandidatesStream( @@ -347,6 +490,8 @@ public abstract static class SubexpressionOptimizerOptions { public abstract boolean populateMacroCalls(); + public abstract boolean enableCelBlock(); + /** Builder for configuring the {@link SubexpressionOptimizerOptions}. */ @AutoValue.Builder public abstract static class Builder { @@ -363,6 +508,12 @@ public abstract static class Builder { */ public abstract Builder populateMacroCalls(boolean value); + /** + * Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed + * to produce a more compact AST. 
+ */ + public abstract Builder enableCelBlock(boolean value); + public abstract SubexpressionOptimizerOptions build(); Builder() {} @@ -372,7 +523,8 @@ public abstract static class Builder { public static Builder newBuilder() { return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder() .iterationLimit(500) - .populateMacroCalls(false); + .populateMacroCalls(false) + .enableCelBlock(false); } SubexpressionOptimizerOptions() {} diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 70b12f53..377756fc 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -98,13 +98,6 @@ public class SubexpressionOptimizerTest { .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) .build(); - private static final CelOptimizer CEL_OPTIMIZER = - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers( - SubexpressionOptimizer.newInstance( - SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build())) - .build(); - private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); private static final TestAllTypes TEST_ALL_TYPES_INPUT = @@ -141,12 +134,24 @@ private static CelBuilder newCelBuilder() { .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); } + private static CelOptimizer newCseOptimizer(SubexpressionOptimizerOptions options) { + return CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers(SubexpressionOptimizer.newInstance(options)) + .build(); + } + @Test - public void cse_producesOptimizedAst() throws Exception { + public void cse_withCelBind_producesOptimizedAst() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("size([0]) + size([0]) 
+ size([1,2]) + size([1,2])").getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(false) + .enableCelBlock(false) + .build()) + .optimize(ast); assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(6); assertThat(optimizedAst.getExpr().toString()) @@ -245,24 +250,102 @@ public void cse_producesOptimizedAst() throws Exception { + "}"); } + @Test + public void cse_withCelBlock_producesOptimizedAst() throws Exception { + CelAbstractSyntaxTree ast = + CEL.compile("size([0]) + size([0]) + size([1,2]) + size([1,2])").getAst(); + CelOptimizer celOptimizer = + newCseOptimizer(SubexpressionOptimizerOptions.newBuilder().enableCelBlock(true).build()); + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(6); + assertThat(optimizedAst.getExpr().toString()) + .isEqualTo( + "CALL [1] {\n" + + " function: cel.@block\n" + + " args: {\n" + + " CREATE_LIST [2] {\n" + + " elements: {\n" + + " CALL [3] {\n" + + " function: size\n" + + " args: {\n" + + " CREATE_LIST [4] {\n" + + " elements: {\n" + + " CONSTANT [5] { value: 0 }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " CALL [6] {\n" + + " function: size\n" + + " args: {\n" + + " CREATE_LIST [7] {\n" + + " elements: {\n" + + " CONSTANT [8] { value: 1 }\n" + + " CONSTANT [9] { value: 2 }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " CALL [10] {\n" + + " function: _+_\n" + + " args: {\n" + + " CALL [11] {\n" + + " function: _+_\n" + + " args: {\n" + + " CALL [12] {\n" + + " function: _+_\n" + + " args: {\n" + + " IDENT [13] {\n" + + " name: @index0\n" + + " }\n" + + " IDENT [14] {\n" + + " name: @index0\n" + + " }\n" + + " }\n" + + " }\n" + + " IDENT [15] {\n" + + " name: @index1\n" + + " }\n" + + " }\n" + + " }\n" + + " IDENT [16] {\n" + + " name: @index1\n" + + " 
}\n" + + " }\n" + + " }\n" + + " }\n" + + "}"); + } + private enum CseTestCase { - SIZE_1("size([1,2]) + size([1,2]) + 1 == 5", "cel.bind(@r0, size([1, 2]), @r0 + @r0) + 1 == 5"), + SIZE_1( + "size([1,2]) + size([1,2]) + 1 == 5", + "cel.bind(@r0, size([1, 2]), @r0 + @r0) + 1 == 5", + "cel.@block([size([1, 2])], @index0 + @index0 + 1 == 5)"), SIZE_2( "2 + size([1,2]) + size([1,2]) + 1 == 7", - "cel.bind(@r0, size([1, 2]), 2 + @r0 + @r0) + 1 == 7"), + "cel.bind(@r0, size([1, 2]), 2 + @r0 + @r0) + 1 == 7", + "cel.@block([size([1, 2])], 2 + @index0 + @index0 + 1 == 7)"), SIZE_3( "size([0]) + size([0]) + size([1,2]) + size([1,2]) == 6", - "cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), @r0 + @r0) + @r1 + @r1) == 6"), + "cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), @r0 + @r0) + @r1 + @r1) == 6", + "cel.@block([size([0]), size([1, 2])], @index0 + @index0 + @index1 + @index1 == 6)"), SIZE_4( "5 + size([0]) + size([0]) + size([1,2]) + size([1,2]) + " + "size([1,2,3]) + size([1,2,3]) == 17", "cel.bind(@r2, size([1, 2, 3]), cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), 5 +" - + " @r0 + @r0) + @r1 + @r1) + @r2 + @r2) == 17"), + + " @r0 + @r0) + @r1 + @r1) + @r2 + @r2) == 17", + "cel.@block([size([0]), size([1, 2]), size([1, 2, 3])], 5 + @index0 + @index0 + @index1 +" + + " @index1 + @index2 + @index2 == 17)"), /** - * Unparsed form is: + * Unparsed form: * *
      * {@code
+     * // With binds
      * cel.bind(@r0, timestamp(int(timestamp(1000000000))).getFullYear(),
      *    cel.bind(@r3, timestamp(int(timestamp(75))),
      *      cel.bind(@r2, timestamp(int(timestamp(200))).getFullYear(),
@@ -274,6 +357,21 @@ private enum CseTestCase {
      *) == 13934
      * }
      * 
+ *
+     * {@code
+     * // With block
+     * cel.@block(
+     *     [
+     *      timestamp(int(timestamp(1000000000))).getFullYear(),
+     *      timestamp(int(timestamp(50))),
+     *      timestamp(int(timestamp(200))).getFullYear(),
+     *      timestamp(int(timestamp(75)))
+     *     ],
+     *     @index0 + @index3.getFullYear() + @index1.getFullYear() + @index0 +
+     *     @index1.getSeconds() + @index2 + @index2 + @index3.getMinutes() + @index0 == 13934
+     * )
+     * 
+ * } */ TIMESTAMP( "timestamp(int(timestamp(1000000000))).getFullYear() +" @@ -290,10 +388,16 @@ private enum CseTestCase { + "cel.bind(@r2, timestamp(int(timestamp(200))).getFullYear(), " + "cel.bind(@r1, timestamp(int(timestamp(50))), " + "@r0 + @r3.getFullYear() + @r1.getFullYear() + " - + "@r0 + @r1.getSeconds()) + @r2 + @r2) + @r3.getMinutes()) + @r0) == 13934"), + + "@r0 + @r1.getSeconds()) + @r2 + @r2) + @r3.getMinutes()) + @r0) == 13934", + "cel.@block([timestamp(int(timestamp(1000000000))).getFullYear()," + + " timestamp(int(timestamp(50))), timestamp(int(timestamp(200))).getFullYear()," + + " timestamp(int(timestamp(75)))], @index0 + @index3.getFullYear() +" + + " @index1.getFullYear() + @index0 + @index1.getSeconds() + @index2 + @index2 +" + + " @index3.getMinutes() + @index0 == 13934)"), MAP_INDEX( "{\"a\": 2}[\"a\"] + {\"a\": 2}[\"a\"] * {\"a\": 2}[\"a\"] == 6", - "cel.bind(@r0, {\"a\": 2}[\"a\"], @r0 + @r0 * @r0) == 6"), + "cel.bind(@r0, {\"a\": 2}[\"a\"], @r0 + @r0 * @r0) == 6", + "cel.@block([{\"a\": 2}[\"a\"]], @index0 + @index0 * @index0 == 6)"), /** * Input map is: * @@ -313,84 +417,117 @@ private enum CseTestCase { NESTED_MAP_CONSTRUCTION( "size({'a': {'b': 1}, 'c': {'b': 1}, 'd': {'e': {'b': 1}}, 'e': {'e': {'b': 1}}}) == 4", "size(cel.bind(@r0, {\"b\": 1}, cel.bind(@r1, {\"e\": @r0}, {\"a\": @r0, \"c\": @r0, \"d\":" - + " @r1, \"e\": @r1}))) == 4"), + + " @r1, \"e\": @r1}))) == 4", + "cel.@block([{\"b\": 1}, {\"e\": @index0}], size({\"a\": @index0, \"c\": @index0, \"d\":" + + " @index1, \"e\": @index1}) == 4)"), NESTED_LIST_CONSTRUCTION( "size([1, [1,2,3,4], 2, [1,2,3,4], 5, [1,2,3,4], 7, [[1,2], [1,2,3,4]], [1,2]]) == 9", "size(cel.bind(@r0, [1, 2, 3, 4], " - + "cel.bind(@r1, [1, 2], [1, @r0, 2, @r0, 5, @r0, 7, [@r1, @r0], @r1]))) == 9"), + + "cel.bind(@r1, [1, 2], [1, @r0, 2, @r0, 5, @r0, 7, [@r1, @r0], @r1]))) == 9", + "cel.@block([[1, 2, 3, 4], [1, 2]], size([1, @index0, 2, @index0, 5, @index0, 7, [@index1," + + " @index0], @index1]) == 
9)"), SELECT( "msg.single_int64 + msg.single_int64 == 6", - "cel.bind(@r0, msg.single_int64, @r0 + @r0) == 6"), + "cel.bind(@r0, msg.single_int64, @r0 + @r0) == 6", + "cel.@block([msg.single_int64], @index0 + @index0 == 6)"), SELECT_NESTED( "msg.oneof_type.payload.single_int64 + msg.oneof_type.payload.single_int32 + " + "msg.oneof_type.payload.single_int64 + " + "msg.single_int64 + msg.oneof_type.payload.oneof_type.payload.single_int64 == 31", "cel.bind(@r0, msg.oneof_type.payload, " + "cel.bind(@r1, @r0.single_int64, @r1 + @r0.single_int32 + @r1) + " - + "msg.single_int64 + @r0.oneof_type.payload.single_int64) == 31"), + + "msg.single_int64 + @r0.oneof_type.payload.single_int64) == 31", + "cel.@block([msg.oneof_type.payload, @index0.single_int64], @index1 + @index0.single_int32" + + " + @index1 + msg.single_int64 + @index0.oneof_type.payload.single_int64 == 31)"), SELECT_NESTED_MESSAGE_MAP_INDEX_1( "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[1] == 15", - "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64[1], @r0 + @r0 + @r0) == 15"), + "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64[1], @r0 + @r0 + @r0) == 15", + "cel.@block([msg.oneof_type.payload.map_int32_int64[1]], @index0 + @index0 + @index0 ==" + + " 15)"), SELECT_NESTED_MESSAGE_MAP_INDEX_2( "msg.oneof_type.payload.map_int32_int64[0] + " + "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[2] == 8", - "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64, @r0[0] + @r0[1] + @r0[2]) == 8"), + "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64, @r0[0] + @r0[1] + @r0[2]) == 8", + "cel.@block([msg.oneof_type.payload.map_int32_int64], @index0[0] + @index0[1] + @index0[2]" + + " == 8)"), TERNARY( "(msg.single_int64 > 0 ? msg.single_int64 : 0) == 3", - "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? @r0 : 0) == 3"), + "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? 
@r0 : 0) == 3", + "cel.@block([msg.single_int64], ((@index0 > 0) ? @index0 : 0) == 3)"), TERNARY_BIND_RHS_ONLY( "false ? false : (msg.single_int64) + ((msg.single_int64 + 1) * 2) == 11", - "false ? false : (cel.bind(@r0, msg.single_int64, @r0 + (@r0 + 1) * 2) == 11)"), + "false ? false : (cel.bind(@r0, msg.single_int64, @r0 + (@r0 + 1) * 2) == 11)", + "cel.@block([msg.single_int64], false ? false : (@index0 + (@index0 + 1) * 2 == 11))"), NESTED_TERNARY( "(msg.single_int64 > 0 ? (msg.single_int32 > 0 ? " + "msg.single_int64 + msg.single_int32 : 0) : 0) == 8", "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? " - + "cel.bind(@r1, msg.single_int32, (@r1 > 0) ? (@r0 + @r1) : 0) : 0) == 8"), + + "cel.bind(@r1, msg.single_int32, (@r1 > 0) ? (@r0 + @r1) : 0) : 0) == 8", + "cel.@block([msg.single_int64, msg.single_int32], ((@index0 > 0) ? ((@index1 > 0) ?" + + " (@index0 + @index1) : 0) : 0) == 8)"), MULTIPLE_MACROS( // Note that all of these have different iteration variables, but they are still logically // the same. 
"size([[1].exists(i, i > 0)]) + size([[1].exists(j, j > 0)]) + " + "size([[2].exists(k, k > 1)]) + size([[2].exists(l, l > 1)]) == 4", "cel.bind(@r1, size([[2].exists(@c0, @c0 > 1)]), " - + "cel.bind(@r0, size([[1].exists(@c0, @c0 > 0)]), @r0 + @r0) + @r1 + @r1) == 4"), + + "cel.bind(@r0, size([[1].exists(@c0, @c0 > 0)]), @r0 + @r0) + @r1 + @r1) == 4", + "cel.@block([size([[1].exists(@c0, @c0 > 0)]), size([[2].exists(@c0, @c0 > 1)])], @index0 +" + + " @index0 + @index1 + @index1 == 4)"), NESTED_MACROS( "[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]", "cel.bind(@r0, [1, 2, 3], @r0.map(@c0, @r0.map(@c1, @c1 + 1))) == " - + "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])"), + + "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])", + "cel.@block([[1, 2, 3], [2, 3, 4]], @index0.map(@c0, @index0.map(@c1, @c1 + 1)) ==" + + " [@index1, @index1, @index1])"), INCLUSION_LIST( "1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]", "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&" - + " @r1))"), + + " @r1))", + "cel.@block([[1, 2, 3], 1 in @index0], @index1 && 2 in @index0 && 3 in [3, @index0] &&" + + " @index1)"), INCLUSION_MAP( "2 in {'a': 1, 2: {true: false}, 3: {true: false}}", - "2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})"), + "2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})", + "cel.@block([{true: false}], 2 in {\"a\": 1, 2: @index0, 3: @index0})"), MACRO_SHADOWED_VARIABLE( "[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3", "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0, @c0 - 1 > 3) ||" - + " @r1))"), + + " @r1))", + "cel.@block([x - 1, @index0 > 3], [@index1 ? 
@index0 : 5].exists(@c0, @c0 - 1 > 3) ||" + + " @index1)"), MACRO_SHADOWED_VARIABLE_2( "size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2", "size([\"foo\", \"bar\"].map(@c1, cel.bind(@r0, @c1 + @c1, [@r0, @r0]))" - + ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2"), + + ".map(@c0, cel.bind(@r1, @c0 + @c0, [@r1, @r1]))) == 2", + "Currently Unsupported"), // TODO: Handle comprehension variables that fall + // outside the cel.block scope PRESENCE_TEST( "has({'a': true}.a) && {'a':true}['a']", - "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])"), + "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])", + "cel.@block([{\"a\": true}], has(@index0.a) && @index0[\"a\"])"), PRESENCE_TEST_WITH_TERNARY( "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10", - "cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10"), + "cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10", + "cel.@block([msg.oneof_type], (has(@index0.payload) ? @index0.payload.single_int64 : 0) ==" + + " 10)"), PRESENCE_TEST_WITH_TERNARY_2( "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :" + " msg.oneof_type.payload.single_int64 * 0) == 10", "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?" - + " @r1 : (@r1 * 0))) == 10"), + + " @r1 : (@r1 * 0))) == 10", + "cel.@block([msg.oneof_type, @index0.payload.single_int64], (has(@index0.payload) ? @index1" + + " : (@index1 * 0)) == 10)"), PRESENCE_TEST_WITH_TERNARY_3( "(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :" + " msg.oneof_type.payload.single_int64 * 0) == 10", "cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64," - + " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10"), + + " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10", + "cel.@block([msg.oneof_type.payload, @index0.single_int64], (has(@index0.single_int64) ?" 
+ + " @index1 : (@index1 * 0)) == 10)"), /** * Input: * @@ -414,6 +551,7 @@ private enum CseTestCase { * Unparsed: * *
{@code
+     * // With binds
      * cel.bind(
      *   @r0, msg.oneof_type,
      *   cel.bind(
@@ -427,6 +565,22 @@ private enum CseTestCase {
      *   ),
      * )
      * }
+ *
{@code
+     * // With block
+     * cel.@block(
+     *   [
+     *     msg.oneof_type,
+     *     @index0.payload,
+     *     @index1.map_string_string
+     *   ],
+     *   (has(msg.oneof_type) && has(@index0.payload) && has(@index1.single_int64)) ?
+     *     (
+     *       (has(@index1.map_string_string) && has(@index2.key)) ?
+     *         (@index2.key == "A") : false
+     *     )
+     *     : false
+     *   )
+     * }
*/ PRESENCE_TEST_WITH_TERNARY_NESTED( "(has(msg.oneof_type) && has(msg.oneof_type.payload) &&" @@ -437,41 +591,59 @@ private enum CseTestCase { "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload, (has(msg.oneof_type) &&" + " has(@r0.payload) && has(@r1.single_int64)) ? cel.bind(@r2, @r1.map_string_string," + " (has(@r1.map_string_string) && has(@r2.key)) ? (@r2.key == \"A\") : false) :" - + " false))"), + + " false))", + "cel.@block([msg.oneof_type, @index0.payload, @index1.map_string_string]," + + " (has(msg.oneof_type) && has(@index0.payload) && has(@index1.single_int64)) ?" + + " ((has(@index1.map_string_string) && has(@index2.key)) ? (@index2.key == \"A\") :" + + " false) : false)"), OPTIONAL_LIST( "[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10," + " [5], [5]]", "cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) ==" - + " cel.bind(@r1, [5], [10, @r1, @r1])"), + + " cel.bind(@r1, [5], [10, @r1, @r1])", + "cel.@block([[?optional.none(), ?opt_x], [5]], [10, ?optional.none(), @index0, @index0] ==" + + " [10, @index1, @index1])"), OPTIONAL_MAP( "{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] ==" + " 'hellohello'", "cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) ==" - + " \"hellohello\""), + + " \"hellohello\"", + "cel.@block([{?\"hello\": optional.of(\"hello\")}[\"hello\"]], @index0 + @index0 ==" + + " \"hellohello\")"), OPTIONAL_MESSAGE( "TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:" + " optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}.single_int64 == 5", "cel.bind(@r0, TestAllTypes{" + "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, " - + "@r0.single_int32 + @r0.single_int64) == 5"), + + "@r0.single_int32 + @r0.single_int64) == 5", + "cel.@block([TestAllTypes{?single_int64: optional.ofNonZeroValue(1), 
?single_int32:" + + " optional.of(4)}], @index0.single_int32 + @index0.single_int64 == 5)"), ; private final String source; - private final String unparsed; + private final String unparsedBind; + private final String unparsedBlock; - CseTestCase(String source, String unparsed) { + CseTestCase(String source, String unparsedBind, String unparsedBlock) { this.source = source; - this.unparsed = unparsed; + this.unparsedBind = unparsedBind; + this.unparsedBlock = unparsedBlock; } } @Test - public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCase) + public void cse_withCelBind_macroMapPopulated(@TestParameter CseTestCase testCase) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(false) + .build()) + .optimize(ast); assertThat( CEL.createProgram(optimizedAst) @@ -479,27 +651,25 @@ public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCas ImmutableMap.of( "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) .isEqualTo(true); - assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsedBind); } @Test - public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) throws Exception { - Cel celWithoutMacroMap = - newCelBuilder() - .setOptions( - CelOptions.current().populateMacroCalls(false).enableTimestampEpoch(true).build()) - .build(); - CelAbstractSyntaxTree ast = celWithoutMacroMap.compile(testCase.source).getAst(); + public void cse_withCelBind_macroMapUnpopulated(@TestParameter CseTestCase testCase) + throws Exception { + CelBuilder celWithoutMacroMap = + newCelBuilder().setOptions(CelOptions.current().enableTimestampEpoch(true).build()); + 
CelAbstractSyntaxTree ast = celWithoutMacroMap.build().compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = - CelOptimizerFactory.standardCelOptimizerBuilder(celWithoutMacroMap) - .addAstOptimizers(SubexpressionOptimizer.getInstance()) - .build() + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(false).build()) .optimize(ast); assertThat(optimizedAst.getSource().getMacroCalls()).isEmpty(); assertThat( celWithoutMacroMap + .build() .createProgram(optimizedAst) .eval( ImmutableMap.of( @@ -507,6 +677,108 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr .isEqualTo(true); } + @Test + public void cse_withCelBlock_macroMapPopulated(@TestParameter CseTestCase testCase) + throws Exception { + if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) { + // TODO: Handle comprehension variables that fall outside the cel.block scope + return; + } + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat( + CEL.createProgram(optimizedAst) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) + .isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsedBlock); + } + + @Test + public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase testCase) + throws Exception { + if (testCase.equals(CseTestCase.MACRO_SHADOWED_VARIABLE_2)) { + // TODO: Handle comprehension variables that fall outside the cel.block scope + return; + } + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(false) + .enableCelBlock(true) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + + 
CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(optimizedAst.getSource().getMacroCalls()).isEmpty(); + assertThat( + CEL.createProgram(optimizedAst) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) + .isEqualTo(true); + } + + @Test + public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { + Cel cel = newCelBuilder().setResultType(SimpleType.BOOL).build(); + CelOptimizer celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder().enableCelBlock(true).build())) + .build(); + CelAbstractSyntaxTree ast = CEL.compile("size('a') + size('a') == 2").getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); + } + + @Test + // Nothing to optimize + @TestParameters("{source: 'size(\"hello\")'}") + // Constants and identifiers + @TestParameters("{source: '2 + 2 + 2 + 2'}") + @TestParameters("{source: 'x + x + x + x'}") + @TestParameters("{source: 'true == true && false == false'}") + // Constants and identifiers within a function + @TestParameters("{source: 'size(\"hello\" + \"hello\" + \"hello\")'}") + @TestParameters("{source: 'string(x + x + x)'}") + // Non-standard functions are considered non-pure for time being + @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") + // Duplicated but nested calls. + @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") + // This cannot be optimized. Extracting the common subexpression would presence test + // the bound identifier (e.g: has(@r0)), which is not valid. + @TestParameters("{source: 'has(msg.single_any) ? 
msg.single_any : 10'}") + public void cse_withCelBind_noop(String source) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(false) + .build()) + .optimize(ast); + + assertThat(ast.getExpr()).isEqualTo(optimizedAst.getExpr()); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + } + @Test // Nothing to optimize @TestParameters("{source: 'size(\"hello\")'}") @@ -524,10 +796,16 @@ public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) thr // This cannot be optimized. Extracting the common subexpression would presence test // the bound identifier (e.g: has(@r0)), which is not valid. @TestParameters("{source: 'has(msg.single_any) ? msg.single_any : 10'}") - public void cse_noop(String source) throws Exception { + public void cse_withCelBlock_noop(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()) + .optimize(ast); assertThat(ast.getExpr()).isEqualTo(optimizedAst.getExpr()); assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); @@ -548,7 +826,13 @@ public void cse_largeCalcExpr() throws Exception { } CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(false) + .build()) + .optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( @@ -582,7 +866,13 @@ public void cse_largeNestedBinds() throws Exception { } CelAbstractSyntaxTree ast = 
CEL.compile(sb.toString()).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(false) + .build()) + .optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( @@ -620,7 +910,58 @@ public void cse_largeNestedBinds() throws Exception { } @Test - public void cse_largeNestedMacro() throws Exception { + public void cse_largeFlattenedBlocks() throws Exception { + StringBuilder sb = new StringBuilder(); + int limit = 50; + for (int i = 0; i < limit; i++) { + sb.append(String.format("size([%d, %d]) + ", i, i + 1)); + sb.append(String.format("size([%d, %d]) ", i, i + 1)); + if (i < limit - 1) { + sb.append("+"); + } + } + CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); + + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()) + .optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.@block([size([0, 1]), size([1, 2]), size([2, 3]), size([3, 4]), size([4, 5])," + + " size([5, 6]), size([6, 7]), size([7, 8]), size([8, 9]), size([9, 10])," + + " size([10, 11]), size([11, 12]), size([12, 13]), size([13, 14]), size([14, 15])," + + " size([15, 16]), size([16, 17]), size([17, 18]), size([18, 19]), size([19, 20])," + + " size([20, 21]), size([21, 22]), size([22, 23]), size([23, 24]), size([24, 25])," + + " size([25, 26]), size([26, 27]), size([27, 28]), size([28, 29]), size([29, 30])," + + " size([30, 31]), size([31, 32]), size([32, 33]), size([33, 34]), size([34, 35])," + + " size([35, 36]), size([36, 37]), size([37, 38]), size([38, 39]), size([39, 40])," + + " size([40, 41]), size([41, 42]), size([42, 43]), size([43, 44]), size([44, 45])," + + " size([45, 46]), size([46, 47]), size([47, 48]), size([48, 49]), size([49," + + " 50])], 
@index0 + @index0 + @index1 + @index1 + @index2 + @index2 + @index3 +" + + " @index3 + @index4 + @index4 + @index5 + @index5 + @index6 + @index6 + @index7 +" + + " @index7 + @index8 + @index8 + @index9 + @index9 + @index10 + @index10 +" + + " @index11 + @index11 + @index12 + @index12 + @index13 + @index13 + @index14 +" + + " @index14 + @index15 + @index15 + @index16 + @index16 + @index17 + @index17 +" + + " @index18 + @index18 + @index19 + @index19 + @index20 + @index20 + @index21 +" + + " @index21 + @index22 + @index22 + @index23 + @index23 + @index24 + @index24 +" + + " @index25 + @index25 + @index26 + @index26 + @index27 + @index27 + @index28 +" + + " @index28 + @index29 + @index29 + @index30 + @index30 + @index31 + @index31 +" + + " @index32 + @index32 + @index33 + @index33 + @index34 + @index34 + @index35 +" + + " @index35 + @index36 + @index36 + @index37 + @index37 + @index38 + @index38 +" + + " @index39 + @index39 + @index40 + @index40 + @index41 + @index41 + @index42 +" + + " @index42 + @index43 + @index43 + @index44 + @index44 + @index45 + @index45 +" + + " @index46 + @index46 + @index47 + @index47 + @index48 + @index48 + @index49 +" + + " @index49)"); + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(200L); + } + + @Test + public void cse_withCelBind_largeNestedMacro() throws Exception { StringBuilder sb = new StringBuilder(); sb.append("size([1,2,3]"); int limit = 8; @@ -639,7 +980,13 @@ public void cse_largeNestedMacro() throws Exception { } CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); - CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(false) + .build()) + .optimize(ast); assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( @@ -649,6 +996,43 @@ public void cse_largeNestedMacro() throws Exception { 
assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); } + @Test + public void cse_withCelBlock_largeNestedMacro() throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("size([1,2,3]"); + int limit = 8; + for (int i = 0; i < limit; i++) { + sb.append(".map(i, [1, 2, 3]"); + } + for (int i = 0; i < limit; i++) { + sb.append(")"); + } + sb.append(")"); + String nestedMapCallExpr = sb.toString(); // size([1,2,3].map(i, [1,2,3].map(i, [1,2,3].map(... + // Add this large macro call 8 times + for (int i = 0; i < limit; i++) { + sb.append("+"); + sb.append(nestedMapCallExpr); + } + CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); + + CelAbstractSyntaxTree optimizedAst = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .build()) + .optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.@block([[1, 2, 3], size(@index0.map(@c0, @index0.map(@c1, @index0.map(@c2," + + " @index0.map(@c3, @index0.map(@c4, @index0.map(@c5, @index0.map(@c6," + + " @index0.map(@c7, @index0)))))))))], @index1 + @index1 + @index1 + @index1 +" + + " @index1 + @index1 + @index1 + @index1 + @index1)"); + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); + } + @Test public void cse_applyConstFoldingAfter() throws Exception { CelAbstractSyntaxTree ast = @@ -658,7 +1042,7 @@ public void cse_applyConstFoldingAfter() throws Exception { CelOptimizerFactory.standardCelOptimizerBuilder(CEL) .addAstOptimizers( SubexpressionOptimizer.newInstance( - SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), + SubexpressionOptimizerOptions.newBuilder().build()), ConstantFoldingOptimizer.getInstance()) .build(); @@ -677,7 +1061,31 @@ public void cse_applyConstFoldingAfter() throws Exception { } @Test - public void iterationLimitReached_throws() throws Exception { + @TestParameters("{enableCelBlock: false, unparsed: 'cel.bind(@r0, 
size(x), @r0 + @r0)'}") + @TestParameters("{enableCelBlock: true, unparsed: 'cel.@block([size(x)], @index0 + @index0)'}") + public void cse_applyConstFoldingAfter_nothingToFold(boolean enableCelBlock, String unparsed) + throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(enableCelBlock) + .build()), + ConstantFoldingOptimizer.getInstance()) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); + } + + @Test + @TestParameters("{enableCelBlock: false}") + @TestParameters("{enableCelBlock: true}") + public void iterationLimitReached_throws(boolean enableCelBlock) throws Exception { StringBuilder largeExprBuilder = new StringBuilder(); int iterationLimit = 100; for (int i = 0; i < iterationLimit; i++) { @@ -692,13 +1100,11 @@ public void iterationLimitReached_throws() throws Exception { assertThrows( CelOptimizationException.class, () -> - CelOptimizerFactory.standardCelOptimizerBuilder(CEL) - .addAstOptimizers( - SubexpressionOptimizer.newInstance( - SubexpressionOptimizerOptions.newBuilder() - .iterationLimit(iterationLimit) - .build())) - .build() + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .iterationLimit(iterationLimit) + .enableCelBlock(enableCelBlock) + .build()) .optimize(ast)); assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached."); }