From 5ebf44e3ffe2559609d5d8f45e7bbd27d6efc942 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 21 Feb 2024 14:02:08 -0800 Subject: [PATCH] Allow setting nesting limit for extractable subexpressions. PiperOrigin-RevId: 609114886 --- .../java/dev/cel/optimizer/MutableAst.java | 106 ++-- .../dev/cel/optimizer/optimizers/BUILD.bazel | 1 + .../optimizers/SubexpressionOptimizer.java | 153 +++-- .../SubexpressionOptimizerTest.java | 539 +++++++++++++++--- 4 files changed, 627 insertions(+), 172 deletions(-) diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 9333726a..1a466df6 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -29,6 +29,7 @@ import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelSource; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.ExprKind.Kind; @@ -41,6 +42,7 @@ import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder; import dev.cel.common.types.CelType; +import java.util.Collection; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map.Entry; @@ -132,6 +134,26 @@ public CelAbstractSyntaxTree replaceSubtree( exprIdToReplace); } + /** Wraps the given AST and its subexpressions with a new cel.@block call. 
*/ + public CelAbstractSyntaxTree wrapAstWithNewCelBlock( + String celBlockFunction, CelAbstractSyntaxTree ast, Collection subexpressions) { + long maxId = getMaxId(ast); + CelExpr blockExpr = + CelExpr.newBuilder() + .setId(++maxId) + .setCall( + CelCall.newBuilder() + .setFunction(celBlockFunction) + .addArgs( + CelExpr.ofCreateListExpr( + ++maxId, ImmutableList.copyOf(subexpressions), ImmutableList.of()), + ast.getExpr()) + .build()) + .build(); + + return CelAbstractSyntaxTree.newParsedAst(blockExpr, ast.getSource()); + } + /** * Generates a new bind macro using the provided initialization and result expression, then * replaces the subtree using the new bind expr at the designated expr ID. @@ -233,13 +255,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( .filter(comprehensionIdentifierPredicate) .filter( node -> { - // Ensure the iter_var is actually referenced in the loop_step. If it's not, we + // Ensure the iter_var or the comprehension result is actually referenced in the + // loop_step. If it's not, we // can skip mangling. String iterVar = node.expr().comprehension().iterVar(); + String result = node.expr().comprehension().result().ident().name(); return CelNavigableExpr.fromExpr(node.expr().comprehension().loopStep()) .allNodes() + .filter(subNode -> subNode.getKind().equals(Kind.IDENT)) + .map(subNode -> subNode.expr().ident()) .anyMatch( - subNode -> subNode.expr().identOrDefault().name().contains(iterVar)); + ident -> ident.name().contains(iterVar) || ident.name().contains(result)); }) .collect( Collectors.toMap( @@ -247,35 +273,29 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames( v -> { CelComprehension comprehension = v.expr().comprehension(); String iterVar = comprehension.iterVar(); - long iterVarId = + // Identifiers to mangle could be the iteration variable, comprehension result + // or both, but at least one has to exist. 
+ // As an example, [1,2].map(i, 3) would produce an optional.empty because `i` + // is not actually used. + Optional iterVarId = CelNavigableExpr.fromExpr(comprehension.loopStep()) .allNodes() .filter( loopStepNode -> loopStepNode.expr().identOrDefault().name().equals(iterVar)) .map(CelNavigableExpr::id) - .findAny() - .orElseThrow( - () -> { - throw new NoSuchElementException( - "Expected iteration variable to exist in expr id: " - + v.id()); - }); - - CelType iterVarType = - ast.getType(iterVarId) - .orElseThrow( - () -> - new NoSuchElementException( - "Checked type not present for iteration variable: " - + iterVarId)); - CelType resultType = - ast.getType(comprehension.result().id()) - .orElseThrow( - () -> - new NoSuchElementException( - "Checked type not present for result: " - + comprehension.result().id())); + .findAny(); + Optional iterVarType = + iterVarId.map( + id -> + ast.getType(id) + .orElseThrow( + () -> + new NoSuchElementException( + "Checked type not present for iteration variable:" + + " " + + iterVarId))); + Optional resultType = ast.getType(comprehension.result().id()); return MangledComprehensionType.of(iterVarType, resultType); }, @@ -487,24 +507,26 @@ private CelSource mangleIdentsInMacroSource( // Note that in the above example, the iteration variable is not replaced after normalization. // This is because populating a macro call map upon parse generates a new unique identifier // that does not exist in the main AST. Thus, we need to manually replace the identifier. + // Also note that this only applies when the macro is at leaf. For nested macros, the iteration + // variable actually exists in the main AST thus, this step isn't needed. + // ex: [1].map(x, [2].filter(y, x == y). Here, the variable declaration `x` exists in the AST + // but not `y`. 
CelExpr.Builder macroExpr = newSource.getMacroCalls().get(originalComprehensionId).toBuilder(); // By convention, the iteration variable is always the first argument of the // macro call expression. CelExpr identToMangle = macroExpr.call().args().get(0); - if (!identToMangle.identOrDefault().name().equals(originalIterVar)) { - throw new IllegalStateException( - String.format( - "Expected %s for iteration variable but got %s instead.", - identToMangle.identOrDefault().name(), originalIterVar)); + if (identToMangle.identOrDefault().name().equals(originalIterVar)) { + macroExpr = + mutateExpr( + NO_OP_ID_GENERATOR, + macroExpr, + CelExpr.newBuilder() + .setIdent( + CelIdent.newBuilder() + .setName(mangledComprehensionName.iterVarName()) + .build()), + identToMangle.id()); } - macroExpr = - mutateExpr( - NO_OP_ID_GENERATOR, - macroExpr, - CelExpr.newBuilder() - .setIdent( - CelIdent.newBuilder().setName(mangledComprehensionName.iterVarName()).build()), - identToMangle.id()); newSource.addMacroCalls(originalComprehensionId, macroExpr.build()); return newSource.build(); @@ -737,12 +759,14 @@ private static MangledComprehensionAst of( public abstract static class MangledComprehensionType { /** Type of iter_var */ - public abstract CelType iterVarType(); + public abstract Optional iterVarType(); /** Type of comprehension result */ - public abstract CelType resultType(); + public abstract Optional resultType(); - private static MangledComprehensionType of(CelType iterVarType, CelType resultType) { + private static MangledComprehensionType of( + Optional iterVarType, Optional resultType) { + Preconditions.checkArgument(iterVarType.isPresent() || resultType.isPresent()); return new AutoValue_MutableAst_MangledComprehensionType(iterVarType, resultType); } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 785ddf74..27932360 100644 --- 
a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -48,6 +48,7 @@ java_library( "//common/navigation", "//common/types", "//common/types:type_providers", + "//extensions:optional_library", "//optimizer:ast_optimizer", "//optimizer:mutable_ast", "//parser:operator", diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index 8cf71e75..4c7c081a 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -20,8 +20,10 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.Streams; import dev.cel.bundle.CelBuilder; import dev.cel.checker.Standard; @@ -35,7 +37,6 @@ import dev.cel.common.CelValidationException; import dev.cel.common.CelVarDecl; import dev.cel.common.ast.CelExpr; -import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.navigation.CelNavigableAst; @@ -44,6 +45,8 @@ import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; import dev.cel.common.types.SimpleType; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.extensions.CelOptionalLibrary.Function; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.MutableAst; import dev.cel.optimizer.MutableAst.MangledComprehensionAst; @@ -91,10 +94,13 @@ public class SubexpressionOptimizer implements CelAstOptimizer { private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = 
Streams.concat( stream(Operator.values()).map(Operator::getFunction), - stream(Standard.Function.values()).map(Standard.Function::getFunction)) + stream(Standard.Function.values()).map(Standard.Function::getFunction), + stream(CelOptionalLibrary.Function.values()).map(Function::getFunction)) .collect(toImmutableSet()); + private static final Extension CEL_BLOCK_AST_EXTENSION_TAG = Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME); + private final SubexpressionOptimizerOptions cseOptions; private final MutableAst mutableAst; @@ -186,6 +192,11 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( throw new IllegalStateException("Max iteration count reached."); } + if (!cseOptions.populateMacroCalls()) { + astToModify = + CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); + } + if (iterCount == 0) { // No modification has been made. return astToModify; @@ -196,36 +207,28 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock( mangledComprehensionAst .mangledComprehensionMap() .forEach( - (name, type) -> - celBuilder.addVarDeclarations( - CelVarDecl.newVarDeclaration(name.iterVarName(), type.iterVarType()), - CelVarDecl.newVarDeclaration(name.resultName(), type.resultType()))); + (name, type) -> { + type.iterVarType() + .ifPresent( + iterVarType -> + celBuilder.addVarDeclarations( + CelVarDecl.newVarDeclaration(name.iterVarName(), iterVarType))); + type.resultType() + .ifPresent( + comprehensionResultType -> + celBuilder.addVarDeclarations( + CelVarDecl.newVarDeclaration( + name.resultName(), comprehensionResultType))); + }); + // Type-check all sub-expressions then add them as block identifiers to the CEL environment addBlockIdentsToEnv(celBuilder, subexpressions); // Wrap the optimized expression in cel.block celBuilder.addFunctionDeclarations(newCelBlockFunctionDecl(resultType)); - int newId = 0; - CelExpr blockExpr = - CelExpr.newBuilder() - .setId(++newId) - .setCall( - CelCall.newBuilder() - 
.setFunction(CEL_BLOCK_FUNCTION) - .addArgs( - CelExpr.ofCreateListExpr( - ++newId, ImmutableList.copyOf(subexpressions), ImmutableList.of()), - astToModify.getExpr()) - .build()) - .build(); astToModify = - mutableAst.renumberIdsConsecutively( - CelAbstractSyntaxTree.newParsedAst(blockExpr, astToModify.getSource())); - - if (!cseOptions.populateMacroCalls()) { - astToModify = - CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); - } + mutableAst.wrapAstWithNewCelBlock(CEL_BLOCK_FUNCTION, astToModify, subexpressions); + astToModify = mutableAst.renumberIdsConsecutively(astToModify); // Restore the expected result type the environment had prior to optimization. celBuilder.setResultType(resultType); @@ -399,14 +402,47 @@ private static CelNavigableExpr getLca(CelAbstractSyntaxTree ast, String boundId } private Optional findCseCandidate(CelAbstractSyntaxTree ast) { - HashSet encounteredNodes = new HashSet<>(); + if (cseOptions.enableCelBlock() && cseOptions.subexpressionMaxRecursionDepth() > 0) { + return findCseCandidateWithRecursionDepth(ast, cseOptions.subexpressionMaxRecursionDepth()); + } else { + return findCseCandidateWithCommonSubexpr(ast); + } + } + + /** + * This retrieves a subexpr candidate based on the recursion limit even if there's no duplicate + * subexpr found. + * + *

TODO: Improve the extraction logic using a suffix tree. + */ + private Optional findCseCandidateWithRecursionDepth( + CelAbstractSyntaxTree ast, int recursionLimit) { + Preconditions.checkArgument(recursionLimit > 0); ImmutableList allNodes = CelNavigableAst.fromAst(ast) .getRoot() - .allNodes(TraversalOrder.PRE_ORDER) + .allNodes(TraversalOrder.POST_ORDER) .filter(SubexpressionOptimizer::canEliminate) + .filter(node -> node.height() <= recursionLimit) + .filter(node -> !areSemanticallyEqual(ast.getExpr(), node.expr())) .collect(toImmutableList()); + if (allNodes.isEmpty()) { + return Optional.empty(); + } + + Optional commonSubexpr = findCseCandidateWithCommonSubexpr(allNodes); + if (commonSubexpr.isPresent()) { + return commonSubexpr; + } + // If there's no common subexpr, just return the one with the highest height that's still below + // the recursion limit. + return Optional.of(Iterables.getLast(allNodes)); + } + + private Optional findCseCandidateWithCommonSubexpr( + ImmutableList allNodes) { + HashSet encounteredNodes = new HashSet<>(); for (CelNavigableExpr node : allNodes) { // Normalize the expr to test semantic equivalence. 
CelExpr celExpr = normalizeForEquality(node.expr()); @@ -420,12 +456,23 @@ private Optional findCseCandidate(CelAbstractSyntaxTree ast) { return Optional.empty(); } + private Optional findCseCandidateWithCommonSubexpr(CelAbstractSyntaxTree ast) { + ImmutableList allNodes = + CelNavigableAst.fromAst(ast) + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .filter(SubexpressionOptimizer::canEliminate) + .collect(toImmutableList()); + + return findCseCandidateWithCommonSubexpr(allNodes); + } + private static boolean canEliminate(CelNavigableExpr navigableExpr) { return !navigableExpr.getKind().equals(Kind.CONSTANT) && !navigableExpr.getKind().equals(Kind.IDENT) && !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX) && !navigableExpr.expr().selectOrDefault().testOnly() - && isAllowedFunction(navigableExpr) + && containsAllowedFunctionOnly(navigableExpr) && isWithinInlineableComprehension(navigableExpr); } @@ -459,12 +506,17 @@ private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) { return normalizeForEquality(expr1).equals(normalizeForEquality(expr2)); } - private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) { - if (navigableExpr.getKind().equals(Kind.CALL)) { - return CSE_ALLOWED_FUNCTIONS.contains(navigableExpr.expr().call().function()); - } - - return true; + private static boolean containsAllowedFunctionOnly(CelNavigableExpr navigableExpr) { + return navigableExpr + .allNodes() + .allMatch( + node -> { + if (node.getKind().equals(Kind.CALL)) { + return CSE_ALLOWED_FUNCTIONS.contains(node.expr().call().function()); + } + + return true; + }); } /** @@ -525,6 +577,8 @@ public abstract static class SubexpressionOptimizerOptions { public abstract boolean enableCelBlock(); + public abstract int subexpressionMaxRecursionDepth(); + /** Builder for configuring the {@link SubexpressionOptimizerOptions}. 
*/ @AutoValue.Builder public abstract static class Builder { @@ -549,6 +603,32 @@ public abstract static class Builder { */ public abstract Builder enableCelBlock(boolean value); + /** + * Ensures all extracted subexpressions do not exceed the maximum depth of designated value. + * The purpose of this is to guarantee evaluation and deserialization safety by preventing + * deeply nested ASTs. The trade-off is increased memory usage due to memoizing additional + * block indices during lazy evaluation. + * + *

As a general note, the root node of an expression has a depth of 0. An expression `x.y.z` has a depth of + * 2. + * + *

Note that expressions containing no common subexpressions may become a candidate for + * extraction to satisfy the max depth requirement. + * + *

This is a no-op if {@link #enableCelBlock} is set to false, or the configured value is + * less than 1. + * + *

Examples: + * + *

    + *
  1. a.b.c with depth 1 -> cel.@block([a.b, @index0.c], @index1) + *
  2. a.b + a.b.c.d with depth 3 -> cel.@block([a.b, @index0.c.d], @index0 + @index1) + *
+ * + *

+ */ + public abstract Builder subexpressionMaxRecursionDepth(int value); + public abstract SubexpressionOptimizerOptions build(); Builder() {} @@ -559,7 +639,8 @@ public static Builder newBuilder() { return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder() .iterationLimit(500) .populateMacroCalls(false) - .enableCelBlock(false); + .enableCelBlock(false) + .subexpressionMaxRecursionDepth(0); } SubexpressionOptimizerOptions() {} diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 4b99c325..57a7485f 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -331,22 +331,30 @@ private enum CseTestCase { SIZE_1( "size([1,2]) + size([1,2]) + 1 == 5", "cel.bind(@r0, size([1, 2]), @r0 + @r0) + 1 == 5", - "cel.@block([size([1, 2])], @index0 + @index0 + 1 == 5)"), + "cel.@block([size([1, 2])], @index0 + @index0 + 1 == 5)", + "cel.@block([[1, 2], size(@index0), @index1 + @index1, @index2 + 1], @index3 == 5)"), SIZE_2( "2 + size([1,2]) + size([1,2]) + 1 == 7", "cel.bind(@r0, size([1, 2]), 2 + @r0 + @r0) + 1 == 7", - "cel.@block([size([1, 2])], 2 + @index0 + @index0 + 1 == 7)"), + "cel.@block([size([1, 2])], 2 + @index0 + @index0 + 1 == 7)", + "cel.@block([[1, 2], size(@index0), 2 + @index1, @index2 + @index1, @index3 + 1], @index4" + + " == 7)"), SIZE_3( "size([0]) + size([0]) + size([1,2]) + size([1,2]) == 6", "cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), @r0 + @r0) + @r1 + @r1) == 6", - "cel.@block([size([0]), size([1, 2])], @index0 + @index0 + @index1 + @index1 == 6)"), + "cel.@block([size([0]), size([1, 2])], @index0 + @index0 + @index1 + @index1 == 6)", + "cel.@block([[0], size(@index0), [1, 2], size(@index2), @index1 + @index1, @index4 +" + + " @index3, @index5 + 
@index3], @index6 == 6)"), SIZE_4( "5 + size([0]) + size([0]) + size([1,2]) + size([1,2]) + " + "size([1,2,3]) + size([1,2,3]) == 17", "cel.bind(@r2, size([1, 2, 3]), cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), 5 +" + " @r0 + @r0) + @r1 + @r1) + @r2 + @r2) == 17", "cel.@block([size([0]), size([1, 2]), size([1, 2, 3])], 5 + @index0 + @index0 + @index1 +" - + " @index1 + @index2 + @index2 == 17)"), + + " @index1 + @index2 + @index2 == 17)", + "cel.@block([[0], size(@index0), [1, 2], size(@index2), [1, 2, 3], size(@index4), 5 +" + + " @index1, @index6 + @index1, @index7 + @index3, @index8 + @index3, @index9 +" + + " @index5, @index10 + @index5], @index11 == 17)"), /** * Unparsed form: * @@ -400,11 +408,21 @@ private enum CseTestCase { + " timestamp(int(timestamp(50))), timestamp(int(timestamp(200))).getFullYear()," + " timestamp(int(timestamp(75)))], @index0 + @index3.getFullYear() +" + " @index1.getFullYear() + @index0 + @index1.getSeconds() + @index2 + @index2 +" - + " @index3.getMinutes() + @index0 == 13934)"), + + " @index3.getMinutes() + @index0 == 13934)", + "cel.@block([timestamp(1000000000), int(@index0), timestamp(@index1)," + + " @index2.getFullYear(), timestamp(50), int(@index4), timestamp(@index5)," + + " timestamp(200), int(@index7), timestamp(@index8), @index9.getFullYear()," + + " timestamp(75), int(@index11), timestamp(@index12), @index13.getMinutes()," + + " @index6.getSeconds(), @index6.getFullYear(), @index13.getFullYear(), @index3 +" + + " @index17, @index18 + @index16, @index19 + @index3, @index20 + @index15, @index21 +" + + " @index10, @index22 + @index10, @index23 + @index14, @index24 + @index3], @index25" + + " == 13934)"), MAP_INDEX( "{\"a\": 2}[\"a\"] + {\"a\": 2}[\"a\"] * {\"a\": 2}[\"a\"] == 6", "cel.bind(@r0, {\"a\": 2}[\"a\"], @r0 + @r0 * @r0) == 6", - "cel.@block([{\"a\": 2}[\"a\"]], @index0 + @index0 * @index0 == 6)"), + "cel.@block([{\"a\": 2}[\"a\"]], @index0 + @index0 * @index0 == 6)", + "cel.@block([{\"a\": 2}, 
@index0[\"a\"], @index1 * @index1, @index1 + @index2], @index3 ==" + + " 6)"), /** * Input map is: * @@ -426,17 +444,22 @@ private enum CseTestCase { "size(cel.bind(@r0, {\"b\": 1}, cel.bind(@r1, {\"e\": @r0}, {\"a\": @r0, \"c\": @r0, \"d\":" + " @r1, \"e\": @r1}))) == 4", "cel.@block([{\"b\": 1}, {\"e\": @index0}], size({\"a\": @index0, \"c\": @index0, \"d\":" - + " @index1, \"e\": @index1}) == 4)"), + + " @index1, \"e\": @index1}) == 4)", + "cel.@block([{\"b\": 1}, {\"e\": @index0}, {\"a\": @index0, \"c\": @index0, \"d\": @index1," + + " \"e\": @index1}, size(@index2)], @index3 == 4)"), NESTED_LIST_CONSTRUCTION( "size([1, [1,2,3,4], 2, [1,2,3,4], 5, [1,2,3,4], 7, [[1,2], [1,2,3,4]], [1,2]]) == 9", "size(cel.bind(@r0, [1, 2, 3, 4], " + "cel.bind(@r1, [1, 2], [1, @r0, 2, @r0, 5, @r0, 7, [@r1, @r0], @r1]))) == 9", "cel.@block([[1, 2, 3, 4], [1, 2]], size([1, @index0, 2, @index0, 5, @index0, 7, [@index1," - + " @index0], @index1]) == 9)"), + + " @index0], @index1]) == 9)", + "cel.@block([[1, 2, 3, 4], [1, 2], [@index1, @index0], [1, @index0, 2, @index0, 5, @index0," + + " 7, @index2, @index1], size(@index3)], @index4 == 9)"), SELECT( "msg.single_int64 + msg.single_int64 == 6", "cel.bind(@r0, msg.single_int64, @r0 + @r0) == 6", - "cel.@block([msg.single_int64], @index0 + @index0 == 6)"), + "cel.@block([msg.single_int64], @index0 + @index0 == 6)", + "cel.@block([msg.single_int64, @index0 + @index0], @index1 == 6)"), SELECT_NESTED( "msg.oneof_type.payload.single_int64 + msg.oneof_type.payload.single_int32 + " + "msg.oneof_type.payload.single_int64 + " @@ -445,37 +468,50 @@ private enum CseTestCase { + "cel.bind(@r1, @r0.single_int64, @r1 + @r0.single_int32 + @r1) + " + "msg.single_int64 + @r0.oneof_type.payload.single_int64) == 31", "cel.@block([msg.oneof_type.payload, @index0.single_int64], @index1 + @index0.single_int32" - + " + @index1 + msg.single_int64 + @index0.oneof_type.payload.single_int64 == 31)"), + + " + @index1 + msg.single_int64 + 
@index0.oneof_type.payload.single_int64 == 31)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.single_int64, @index1.oneof_type," + + " @index3.payload, @index4.single_int64, msg.single_int64, @index1.single_int32," + + " @index2 + @index7, @index8 + @index2, @index9 + @index6, @index10 + @index5]," + + " @index11 == 31)"), SELECT_NESTED_MESSAGE_MAP_INDEX_1( "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[1] == 15", "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64[1], @r0 + @r0 + @r0) == 15", "cel.@block([msg.oneof_type.payload.map_int32_int64[1]], @index0 + @index0 + @index0 ==" - + " 15)"), + + " 15)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.map_int32_int64, @index2[1], @index3" + + " + @index3, @index4 + @index3], @index5 == 15)"), SELECT_NESTED_MESSAGE_MAP_INDEX_2( "msg.oneof_type.payload.map_int32_int64[0] + " + "msg.oneof_type.payload.map_int32_int64[1] + " + "msg.oneof_type.payload.map_int32_int64[2] == 8", "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64, @r0[0] + @r0[1] + @r0[2]) == 8", "cel.@block([msg.oneof_type.payload.map_int32_int64], @index0[0] + @index0[1] + @index0[2]" - + " == 8)"), + + " == 8)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.map_int32_int64, @index2[2]," + + " @index2[1], @index2[0], @index5 + @index4, @index6 + @index3], @index7 == 8)"), TERNARY( "(msg.single_int64 > 0 ? msg.single_int64 : 0) == 3", "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? @r0 : 0) == 3", - "cel.@block([msg.single_int64], ((@index0 > 0) ? @index0 : 0) == 3)"), + "cel.@block([msg.single_int64], ((@index0 > 0) ? @index0 : 0) == 3)", + "cel.@block([msg.single_int64, @index0 > 0, @index1 ? @index0 : 0], @index2 == 3)"), TERNARY_BIND_RHS_ONLY( "false ? false : (msg.single_int64) + ((msg.single_int64 + 1) * 2) == 11", "false ? 
false : (cel.bind(@r0, msg.single_int64, @r0 + (@r0 + 1) * 2) == 11)", - "cel.@block([msg.single_int64], false ? false : (@index0 + (@index0 + 1) * 2 == 11))"), + "cel.@block([msg.single_int64], false ? false : (@index0 + (@index0 + 1) * 2 == 11))", + "cel.@block([msg.single_int64, @index0 + 1, @index1 * 2, @index0 + @index2, @index3 == 11]," + + " false ? false : @index4)"), NESTED_TERNARY( "(msg.single_int64 > 0 ? (msg.single_int32 > 0 ? " + "msg.single_int64 + msg.single_int32 : 0) : 0) == 8", "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? " + "cel.bind(@r1, msg.single_int32, (@r1 > 0) ? (@r0 + @r1) : 0) : 0) == 8", "cel.@block([msg.single_int64, msg.single_int32], ((@index0 > 0) ? ((@index1 > 0) ?" - + " (@index0 + @index1) : 0) : 0) == 8)"), - MULTIPLE_MACROS( + + " (@index0 + @index1) : 0) : 0) == 8)", + "cel.@block([msg.single_int64, msg.single_int32, @index0 + @index1, @index1 > 0, @index3 ?" + + " @index2 : 0, @index0 > 0, @index5 ? @index4 : 0], @index6 == 8)"), + MULTIPLE_MACROS_1( // Note that all of these have different iteration variables, but they are still logically // the same. 
"size([[1].exists(i, i > 0)]) + size([[1].exists(j, j > 0)]) + " @@ -483,66 +519,105 @@ private enum CseTestCase { "cel.bind(@r1, size([[2].exists(@c0:0, @c0:0 > 1)]), " + "cel.bind(@r0, size([[1].exists(@c0:0, @c0:0 > 0)]), @r0 + @r0) + @r1 + @r1) == 4", "cel.@block([size([[1].exists(@c0:0, @c0:0 > 0)]), size([[2].exists(@c0:0, @c0:0 > 1)])]," - + " @index0 + @index0 + @index1 + @index1 == 4)"), + + " @index0 + @index0 + @index1 + @index1 == 4)", + "cel.@block([[1], @c0:0 > 0, @x0:0 || @index1, [2], @c0:0 > 1, @x0:0 || @index4]," + + " size([@index0.exists(@c0:0, @index1)]) + size([@index0.exists(@c0:0, @index1)]) +" + + " size([@index3.exists(@c0:0, @index4)]) + size([@index3.exists(@c0:0, @index4)]) ==" + + " 4)"), MULTIPLE_MACROS_2( "[[1].exists(i, i > 0)] + [[1].exists(j, j > 0)] + [['a'].exists(k, k == 'a')] +" + " [['a'].exists(l, l == 'a')] == [true, true, true, true]", "cel.bind(@r1, [[\"a\"].exists(@c0:1, @c0:1 == \"a\")], cel.bind(@r0, [[1].exists(@c0:0," + " @c0:0 > 0)], @r0 + @r0) + @r1 + @r1) == [true, true, true, true]", "cel.@block([[[1].exists(@c0:0, @c0:0 > 0)], [[\"a\"].exists(@c0:1, @c0:1 == \"a\")]]," - + " @index0 + @index0 + @index1 + @index1 == [true, true, true, true])"), + + " @index0 + @index0 + @index1 + @index1 == [true, true, true, true])", + "cel.@block([[1], @c0:0 > 0, @x0:0 || @index1, [\"a\"], @c0:1 == \"a\", @x0:1 || @index4," + + " [true, true, true, true]], [@index0.exists(@c0:0, @index1)] +" + + " [@index0.exists(@c0:0, @index1)] + [@index3.exists(@c0:1, @index4)] +" + + " [@index3.exists(@c0:1, @index4)] == @index6)"), NESTED_MACROS( "[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]", "cel.bind(@r0, [1, 2, 3], @r0.map(@c0:0, @r0.map(@c1:0, @c1:0 + 1))) == " + "cel.bind(@r1, [2, 3, 4], [@r1, @r1, @r1])", "cel.@block([[1, 2, 3], [2, 3, 4]], @index0.map(@c0:0, @index0.map(@c1:0, @c1:0 + 1)) ==" - + " [@index1, @index1, @index1])"), + + " [@index1, @index1, @index1])", + "cel.@block([[1, 2, 3], [2, 3, 
4], [@index1, @index1, @index1], @c1:0 + 1, [@index3], @x1:0" + + " + @index4], @index0.map(@c0:0, @index0.map(@c1:0, @index3)) == @index2)"), + NESTED_MACROS_2( + "[1, 2].map(y, [1, 2, 3].filter(x, x == y)) == [[1], [2]]", + "[1, 2].map(@c0:0, [1, 2, 3].filter(@c1:0, @c1:0 == @c0:0)) == [[1], [2]]", + "[1, 2].map(@c0:0, [1, 2, 3].filter(@c1:0, @c1:0 == @c0:0)) == [[1], [2]]", + "cel.@block([[2], [1], [@index1, @index0], [@c1:0], @x1:0 + @index3, @c1:0 == @c0:0," + + " @index5 ? @index4 : @x1:0, [1, 2, 3], [1, 2]], @index8.map(@c0:0," + + " @index7.filter(@c1:0, @index5)) == @index2)"), INCLUSION_LIST( "1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]", "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, 1 in @r0, @r1 && 2 in @r0 && 3 in [3, @r0] &&" + " @r1))", "cel.@block([[1, 2, 3], 1 in @index0], @index1 && 2 in @index0 && 3 in [3, @index0] &&" - + " @index1)"), + + " @index1)", + "cel.@block([[1, 2, 3], 1 in @index0, [3, @index0], 3 in @index2, @index3 && @index1, 2 in" + + " @index0, @index1 && @index5], @index6 && @index4)"), INCLUSION_MAP( "2 in {'a': 1, 2: {true: false}, 3: {true: false}}", "2 in cel.bind(@r0, {true: false}, {\"a\": 1, 2: @r0, 3: @r0})", - "cel.@block([{true: false}], 2 in {\"a\": 1, 2: @index0, 3: @index0})"), + "cel.@block([{true: false}], 2 in {\"a\": 1, 2: @index0, 3: @index0})", + "cel.@block([{true: false}, {\"a\": 1, 2: @index0, 3: @index0}], 2 in @index1)"), + MACRO_ITER_VAR_NOT_REFERENCED( + "[1,2].map(i, [1, 2].map(i, [3,4])) == [[[3, 4], [3, 4]], [[3, 4], [3, 4]]]", + "cel.bind(@r1, [3, 4], cel.bind(@r0, [1, 2], @r0.map(@c0:0, @r0.map(@c1:0, @r1))) ==" + + " cel.bind(@r2, [@r1, @r1], [@r2, @r2]))", + "cel.@block([[1, 2], [3, 4], [@index1, @index1]], @index0.map(@c0:0, @index0.map(@c1:0," + + " @index1)) == [@index2, @index2])", + "cel.@block([[1, 2], [3, 4], [@index1, @index1], [@index2, @index2], [@index1], @x1:0 +" + + " @index4], @index0.map(@c0:0, @index0.map(@c1:0, @index1)) == @index3)"), MACRO_SHADOWED_VARIABLE( "[x 
- 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3", "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0:0, @c0:0 - 1 > 3)" + " || @r1))", "cel.@block([x - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@c0:0, @c0:0 - 1 > 3) ||" - + " @index1)"), + + " @index1)", + "cel.@block([x - 1, @index0 > 3, @c0:0 - 1, @index2 > 3, @x0:0 || @index3, @index1 ?" + + " @index0 : 5, [@index5]], @index6.exists(@c0:0, @index3) || @index1)"), MACRO_SHADOWED_VARIABLE_2( "size([\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])) == 2", "size([\"foo\", \"bar\"].map(@c1:0, cel.bind(@r0, @c1:0 + @c1:0, [@r0, @r0]))" + ".map(@c0:0, cel.bind(@r1, @c0:0 + @c0:0, [@r1, @r1]))) == 2", - "cel.@block([@c1:0 + @c1:0, @c0:0 + @c0:0], " - + "size([\"foo\", \"bar\"].map(@c1:0, [@index0, @index0])" - + ".map(@c0:0, [@index1, @index1])) == 2)"), + "cel.@block([@c1:0 + @c1:0, @c0:0 + @c0:0], size([\"foo\", \"bar\"].map(@c1:0, [@index0," + + " @index0]).map(@c0:0, [@index1, @index1])) == 2)", + "cel.@block([@c1:0 + @c1:0, @c0:0 + @c0:0, [@index1, @index1], [@index2], @x0:0 + @index3," + + " [@index0, @index0], [@index5], @x1:0 + @index6, [\"foo\", \"bar\"]]," + + " size(@index8.map(@c1:0, @index5).map(@c0:0, @index2)) == 2)"), PRESENCE_TEST( "has({'a': true}.a) && {'a':true}['a']", "cel.bind(@r0, {\"a\": true}, has(@r0.a) && @r0[\"a\"])", - "cel.@block([{\"a\": true}], has(@index0.a) && @index0[\"a\"])"), + "cel.@block([{\"a\": true}], has(@index0.a) && @index0[\"a\"])", + "cel.@block([{\"a\": true}, @index0[\"a\"]], has(@index0.a) && @index1)"), PRESENCE_TEST_WITH_TERNARY( "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 : 0) == 10", "cel.bind(@r0, msg.oneof_type, has(@r0.payload) ? @r0.payload.single_int64 : 0) == 10", "cel.@block([msg.oneof_type], (has(@index0.payload) ? @index0.payload.single_int64 : 0) ==" - + " 10)"), + + " 10)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.single_int64], (has(@index0.payload)" + + " ? 
@index2 : 0) == 10)"), PRESENCE_TEST_WITH_TERNARY_2( "(has(msg.oneof_type.payload) ? msg.oneof_type.payload.single_int64 :" + " msg.oneof_type.payload.single_int64 * 0) == 10", "cel.bind(@r0, msg.oneof_type, cel.bind(@r1, @r0.payload.single_int64, has(@r0.payload) ?" + " @r1 : (@r1 * 0))) == 10", "cel.@block([msg.oneof_type, @index0.payload.single_int64], (has(@index0.payload) ? @index1" - + " : (@index1 * 0)) == 10)"), + + " : (@index1 * 0)) == 10)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.single_int64, @index2 * 0]," + + " (has(@index0.payload) ? @index2 : @index3) == 10)"), PRESENCE_TEST_WITH_TERNARY_3( "(has(msg.oneof_type.payload.single_int64) ? msg.oneof_type.payload.single_int64 :" + " msg.oneof_type.payload.single_int64 * 0) == 10", "cel.bind(@r0, msg.oneof_type.payload, cel.bind(@r1, @r0.single_int64," + " has(@r0.single_int64) ? @r1 : (@r1 * 0))) == 10", "cel.@block([msg.oneof_type.payload, @index0.single_int64], (has(@index0.single_int64) ?" - + " @index1 : (@index1 * 0)) == 10)"), + + " @index1 : (@index1 * 0)) == 10)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.single_int64, @index2 * 0]," + + " (has(@index1.single_int64) ? @index2 : @index3) == 10)"), /** * Input: * @@ -610,21 +685,41 @@ private enum CseTestCase { "cel.@block([msg.oneof_type, @index0.payload, @index1.map_string_string]," + " (has(msg.oneof_type) && has(@index0.payload) && has(@index1.single_int64)) ?" + " ((has(@index1.map_string_string) && has(@index2.key)) ? (@index2.key == \"A\") :" - + " false) : false)"), + + " false) : false)", + "cel.@block([msg.oneof_type, @index0.payload, @index1.map_string_string, @index2.key," + + " @index3 == \"A\"], (has(msg.oneof_type) && has(@index0.payload) &&" + + " has(@index1.single_int64)) ? ((has(@index1.map_string_string) && has(@index2.key))" + + " ? 
@index4 : false) : false)"), OPTIONAL_LIST( "[10, ?optional.none(), [?optional.none(), ?opt_x], [?optional.none(), ?opt_x]] == [10," + " [5], [5]]", - "cel.bind(@r0, [?optional.none(), ?opt_x], [10, ?optional.none(), @r0, @r0]) ==" - + " cel.bind(@r1, [5], [10, @r1, @r1])", - "cel.@block([[?optional.none(), ?opt_x], [5]], [10, ?optional.none(), @index0, @index0] ==" - + " [10, @index1, @index1])"), + "cel.bind(@r0, optional.none(), cel.bind(@r1, [?@r0, ?opt_x], [10, ?@r0, @r1, @r1])) ==" + + " cel.bind(@r2, [5], [10, @r2, @r2])", + "cel.@block([optional.none(), [?@index0, ?opt_x], [5]], [10, ?@index0, @index1, @index1] ==" + + " [10, @index2, @index2])", + "cel.@block([optional.none(), [?@index0, ?opt_x], [5], [10, @index2, @index2], [10," + + " ?@index0, @index1, @index1]], @index4 == @index3)"), OPTIONAL_MAP( "{?'hello': optional.of('hello')}['hello'] + {?'hello': optional.of('hello')}['hello'] ==" + " 'hellohello'", "cel.bind(@r0, {?\"hello\": optional.of(\"hello\")}[\"hello\"], @r0 + @r0) ==" + " \"hellohello\"", "cel.@block([{?\"hello\": optional.of(\"hello\")}[\"hello\"]], @index0 + @index0 ==" - + " \"hellohello\")"), + + " \"hellohello\")", + "cel.@block([optional.of(\"hello\"), {?\"hello\": @index0}, @index1[\"hello\"], @index2 +" + + " @index2], @index3 == \"hellohello\")"), + OPTIONAL_MAP_CHAINED( + "{?'key': optional.of('test')}[?'bogus'].or({'key': 'test'}[?'bogus']).orValue({'key':" + + " 'test'}['key']) == 'test'", + "cel.bind(@r0, {\"key\": \"test\"}, {?\"key\":" + + " optional.of(\"test\")}[?\"bogus\"].or(@r0[?\"bogus\"]).orValue(@r0[\"key\"])) ==" + + " \"test\"", + "cel.@block([{\"key\": \"test\"}], {?\"key\":" + + " optional.of(\"test\")}[?\"bogus\"].or(@index0[?\"bogus\"]).orValue(@index0[\"key\"])" + + " == \"test\")", + "cel.@block([{\"key\": \"test\"}, @index0[\"key\"], @index0[?\"bogus\"], {?\"key\":" + + " optional.of(\"test\")}, @index3[?\"bogus\"], @index4.or(@index2)," + + " @index5.orValue(@index1)], @index6 == \"test\")"), 
OPTIONAL_MESSAGE( "TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:" @@ -633,17 +728,23 @@ private enum CseTestCase { + "?single_int64: optional.ofNonZeroValue(1), ?single_int32: optional.of(4)}, " + "@r0.single_int32 + @r0.single_int64) == 5", "cel.@block([TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" - + " optional.of(4)}], @index0.single_int32 + @index0.single_int64 == 5)"), + + " optional.of(4)}], @index0.single_int32 + @index0.single_int64 == 5)", + "cel.@block([optional.ofNonZeroValue(1), optional.of(4), TestAllTypes{?single_int64:" + + " @index0, ?single_int32: @index1}, @index2.single_int64, @index2.single_int32," + + " @index4 + @index3], @index5 == 5)"), ; private final String source; private final String unparsedBind; private final String unparsedBlock; + private final String unparsedBlockFlattened; - CseTestCase(String source, String unparsedBind, String unparsedBlock) { + CseTestCase( + String source, String unparsedBind, String unparsedBlock, String unparsedBlockFlattened) { this.source = source; this.unparsedBind = unparsedBind; this.unparsedBlock = unparsedBlock; + this.unparsedBlockFlattened = unparsedBlockFlattened; } } @@ -737,30 +838,283 @@ public void cse_withCelBlock_macroMapUnpopulated(@TestParameter CseTestCase test } @Test - public void celBlock_nestedComprehension_iterVarReferencedAcrossComprehensions() + public void cse_withCelBlockFlattened_macroMapPopulated(@TestParameter CseTestCase testCase) throws Exception { - String nestedComprehension = - "[\"foo\"].map(x, [[\"bar\"], [x + x, x + x]] + [\"bar\"].map(y, [x + y, [\"baz\"].map(z," - + " [x + y + z, x + y, x + y + z])])) == [[[\"bar\"], [\"foofoo\", \"foofoo\"]," - + " [\"foobar\", [[\"foobarbaz\", \"foobar\", \"foobarbaz\"]]]]]"; CelOptimizer celOptimizer = newCseOptimizer( SubexpressionOptimizerOptions.newBuilder() .populateMacroCalls(true) .enableCelBlock(true) + 
.subexpressionMaxRecursionDepth(1) .build()); - CelAbstractSyntaxTree ast = CEL.compile(nestedComprehension).getAst(); + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat( + CEL.createProgram(optimizedAst) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) + .isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsedBlockFlattened); + } + + @Test + public void cse_withVariousRecursionDepths_macroMapUnpopulated( + @TestParameter CseTestCase testCase, + @TestParameter({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10"}) Integer maxRecursionDepth) + throws Exception { + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(false) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(maxRecursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(optimizedAst.getSource().getMacroCalls()).isEmpty(); + assertThat( + CEL.createProgram(optimizedAst) + .eval( + ImmutableMap.of( + "msg", TEST_ALL_TYPES_INPUT, "x", 5L, "opt_x", Optional.of(5L)))) + .isEqualTo(true); + } + + @Test + @TestParameters( + "{recursionDepth: 0, unparsed: 'true ||" + + " msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload.single_int64" + + " == 1'}") + @TestParameters( + "{recursionDepth: 1, unparsed: 'cel.@block([msg.oneof_type, @index0.payload," + + " @index1.oneof_type, @index2.payload, @index3.oneof_type, @index4.payload," + + " @index5.oneof_type, @index6.payload, @index7.single_int64, @index8 == 1], true ||" + + " @index9)'}") + @TestParameters( + "{recursionDepth: 2, unparsed: 'cel.@block([msg.oneof_type.payload," + + " @index0.oneof_type.payload, @index1.oneof_type.payload, @index2.oneof_type.payload," + + 
" @index3.single_int64 == 1], true || @index4)'}") + @TestParameters( + "{recursionDepth: 3, unparsed: 'cel.@block([msg.oneof_type.payload.oneof_type," + + " @index0.payload.oneof_type.payload, @index1.oneof_type.payload.single_int64, @index2" + + " == 1], true || @index3)'}") + @TestParameters( + "{recursionDepth: 4, unparsed: 'cel.@block([msg.oneof_type.payload.oneof_type.payload," + + " @index0.oneof_type.payload.oneof_type.payload, @index1.single_int64 == 1], true ||" + + " @index2)'}") + @TestParameters( + "{recursionDepth: 5, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type," + + " @index0.payload.oneof_type.payload.single_int64 == 1], true || @index1)'}") + @TestParameters( + "{recursionDepth: 6, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type.payload," + + " @index0.oneof_type.payload.single_int64 == 1], true || @index1)'}") + @TestParameters( + "{recursionDepth: 7, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type," + + " @index0.payload.single_int64 == 1], true || @index1)'}") + @TestParameters( + "{recursionDepth: 8, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload," + + " @index0.single_int64 == 1], true || @index1)'}") + @TestParameters( + "{recursionDepth: 9, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload.single_int64," + + " @index0 == 1], true || @index1)'}") + @TestParameters( + "{recursionDepth: 10, unparsed:" + + " 'cel.@block([msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload.single_int64" + + " == 1], true || @index0)'}") + public void noCommonSubexpr_withRecursionDepth_deeplyNestedSelect( + int recursionDepth, String unparsed) throws Exception { + String expression = + "true ||" + + " 
msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload.single_int64" + + " == 1"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(recursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(expression).getAst(); CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); - assertThat(CEL_UNPARSER.unparse(optimizedAst)) - .isEqualTo( - "cel.@block([@c0:0 + @c0:0, [\"bar\"], @c0:0 + @c1:0, @index2 + @c2:0]," - + " [\"foo\"].map(@c0:0, [@index1, [@index0, @index0]] + @index1.map(@c1:0," - + " [@index2, [\"baz\"].map(@c2:0, [@index3, @index2, @index3])])) == [[@index1," - + " [\"foofoo\", \"foofoo\"], [\"foobar\", [[\"foobarbaz\", \"foobar\"," - + " \"foobarbaz\"]]]]])"); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); + } + + @Test + @TestParameters( + "{recursionDepth: 1, unparsed: 'cel.@block([msg.oneof_type, @index0.payload," + + " @index1.oneof_type, @index2.payload, @index3.oneof_type, @index4.child," + + " @index5.child, @index6.payload, @index7.single_bool, @index4.payload," + + " @index9.oneof_type, @index10.payload, @index11.single_bool, true || @index12]," + + " @index13 || @index8)'}") + @TestParameters( + "{recursionDepth: 2, unparsed: 'cel.@block([msg.oneof_type, @index0.payload," + + " @index1.oneof_type, @index2.payload, @index3.oneof_type, @index4.child.child," + + " @index5.payload.single_bool, @index4.payload.oneof_type, @index7.payload.single_bool," + + " true || @index8], @index9 || @index6)'}") + @TestParameters( + "{recursionDepth: 3, unparsed: 'cel.@block([msg.oneof_type, @index0.payload," + + " @index1.oneof_type, @index2.payload, @index3.oneof_type, @index4.child.child.payload," + + " @index5.single_bool, @index4.payload.oneof_type.payload, true ||" + + " @index7.single_bool], @index8 
|| @index6)'}") + @TestParameters( + "{recursionDepth: 4, unparsed: 'cel.@block([msg.oneof_type, @index0.payload," + + " @index1.oneof_type, @index2.payload, @index3.oneof_type," + + " @index4.child.child.payload.single_bool," + + " @index4.payload.oneof_type.payload.single_bool, true || @index6], @index7 ||" + + " @index5)'}") + public void cse_withRecursionDepth_deeplyNestedSelect(int recursionDepth, String unparsed) + throws Exception { + String expression = + "true ||" + + " msg.oneof_type.payload.oneof_type.payload.oneof_type.payload.oneof_type.payload.single_bool" + + " || msg.oneof_type.payload.oneof_type.payload.oneof_type.child.child.payload.single_bool"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(recursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(expression).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); + } + + @Test + @TestParameters( + "{recursionDepth: 0, unparsed: '\"hello world\".matches(\"h\" + \"e\" + \"l\" + \"l\" +" + + " \"o\") == true'}") + @TestParameters( + "{recursionDepth: 1, unparsed: 'cel.@block([\"h\" + \"e\", @index0 + \"l\", @index1 + \"l\"," + + " @index2 + \"o\", \"hello world\".matches(@index3)], @index4 == true)'}") + @TestParameters( + "{recursionDepth: 2, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\", @index0 + \"l\" + \"o\"," + + " \"hello world\".matches(@index1)], @index2 == true)'}") + @TestParameters( + "{recursionDepth: 3, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\" + \"l\", \"hello" + + " world\".matches(@index0 + \"o\")], @index1 == true)'}") + @TestParameters( + "{recursionDepth: 4, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\" + \"l\" + \"o\", \"hello" + + " world\".matches(@index0)], @index1 == 
true)'}") + @TestParameters( + "{recursionDepth: 5, unparsed: 'cel.@block([\"hello world\".matches(\"h\" + \"e\" + \"l\" +" + + " \"l\" + \"o\")], @index0 == true)'}") + public void noCommonSubexpr_withRecursionDepth_deeplyNestedCallOnArgs( + int recursionDepth, String unparsed) throws Exception { + String expression = "'hello world'.matches('h' + 'e' + 'l' + 'l' + 'o') == true"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(recursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(expression).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); + } + + @Test + @TestParameters( + "{recursionDepth: 0, unparsed: '(\"h\" + \"e\" + \"l\" + \"l\" + \"o\" + \"" + + " world\").matches(\"hello\") == true'}") + @TestParameters( + "{recursionDepth: 1, unparsed: 'cel.@block([\"h\" + \"e\", @index0 + \"l\", @index1 + \"l\"," + + " @index2 + \"o\", @index3 + \" world\", @index4.matches(\"hello\")], @index5 ==" + + " true)'}") + @TestParameters( + "{recursionDepth: 2, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\", @index0 + \"l\" + \"o\"," + + " (@index1 + \" world\").matches(\"hello\")], @index2 == true)'}") + @TestParameters( + "{recursionDepth: 3, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\" + \"l\", (@index0 + \"o\" +" + + " \" world\").matches(\"hello\")], @index1 == true)'}") + @TestParameters( + "{recursionDepth: 4, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\" + \"l\" + \"o\", (@index0 +" + + " \" world\").matches(\"hello\")], @index1 == true)'}") + @TestParameters( + "{recursionDepth: 5, unparsed: 'cel.@block([\"h\" + \"e\" + \"l\" + \"l\" + \"o\" + \"" + + " world\", @index0.matches(\"hello\")], @index1 == true)'}") + @TestParameters( + "{recursionDepth: 6, 
unparsed: 'cel.@block([(\"h\" + \"e\" + \"l\" + \"l\" + \"o\" + \"" + + " world\").matches(\"hello\")], @index0 == true)'}") + public void noCommonSubexpr_withRecursionDepth_deeplyNestedCallOnTarget( + int recursionDepth, String unparsed) throws Exception { + String expression = "('h' + 'e' + 'l' + 'l' + 'o' + ' world').matches('hello') == true"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(recursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(expression).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); + } + + @Test + @TestParameters( + "{recursionDepth: 1, unparsed: 'cel.@block([\"w\" + \"o\", @index0 + \"r\", @index1 + \"l\"," + + " @index2 + \"d\", \"h\" + \"e\", @index4 + \"l\", @index5 + \"l\", @index6 + \"o\"," + + " @index7 + \" world\", @index8.matches(@index3)], @index9 == true)'}") + @TestParameters( + "{recursionDepth: 2, unparsed: 'cel.@block([\"w\" + \"o\" + \"r\", @index0 + \"l\" + \"d\"," + + " \"h\" + \"e\" + \"l\", @index2 + \"l\" + \"o\", (@index3 + \"" + + " world\").matches(@index1)], @index4 == true)'}") + @TestParameters( + "{recursionDepth: 3, unparsed: 'cel.@block([\"w\" + \"o\" + \"r\" + \"l\", @index0 + \"d\"," + + " \"h\" + \"e\" + \"l\" + \"l\", (@index2 + \"o\" + \" world\").matches(@index1)]," + + " @index3 == true)'}") + @TestParameters( + "{recursionDepth: 4, unparsed: 'cel.@block([\"w\" + \"o\" + \"r\" + \"l\" + \"d\", \"h\" +" + + " \"e\" + \"l\" + \"l\" + \"o\", (@index1 + \" world\").matches(@index0)], @index2 ==" + + " true)'}") + @TestParameters( + "{recursionDepth: 5, unparsed: 'cel.@block([\"w\" + \"o\" + \"r\" + \"l\" + \"d\", \"h\" +" + + " \"e\" + \"l\" + \"l\" + \"o\" + \" world\", 
@index1.matches(@index0)], @index2 ==" + + " true)'}") + @TestParameters( + "{recursionDepth: 6, unparsed: 'cel.@block([(\"h\" + \"e\" + \"l\" + \"l\" + \"o\" + \"" + + " world\").matches(\"w\" + \"o\" + \"r\" + \"l\" + \"d\")], @index0 == true)'}") + public void noCommonSubexpr_withRecursionDepth_deeplyNestedCallOnBothTargetAndArgs( + int recursionDepth, String unparsed) throws Exception { + String expression = + "('h' + 'e' + 'l' + 'l' + 'o' + ' world').matches('w' + 'o' + 'r' + 'l' + 'd') == true"; + CelOptimizer celOptimizer = + newCseOptimizer( + SubexpressionOptimizerOptions.newBuilder() + .populateMacroCalls(true) + .enableCelBlock(true) + .subexpressionMaxRecursionDepth(recursionDepth) + .build()); + CelAbstractSyntaxTree ast = CEL.compile(expression).getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(unparsed); } @Test @@ -781,25 +1135,35 @@ public void cse_resultTypeSet_celBlockOptimizationSuccess() throws Exception { .isEqualTo("cel.@block([size(\"a\")], @index0 + @index0 == 2)"); } + private enum CseNoOpTestCase { + // Nothing to optimize + NO_COMMON_SUBEXPR("size(\"hello\")"), + // Constants and identifiers + INT_CONST_ONLY("2 + 2 + 2 + 2"), + IDENT_ONLY("x + x + x + x"), + BOOL_CONST_ONLY("true == true && false == false"), + // Constants and identifiers within a function + CONST_WITHIN_FUNCTION("size(\"hello\" + \"hello\" + \"hello\")"), + IDENT_WITHIN_FUNCTION("string(x + x + x)"), + // Non-standard functions are considered non-pure for time being + NON_STANDARD_FUNCTION_1("custom_func(1) + custom_func(1)"), + NON_STANDARD_FUNCTION_2("1 + custom_func(1) + 1 + custom_func(1)"), + // Duplicated but nested calls. + NESTED_FUNCTION("int(timestamp(int(timestamp(1000000000))))"), + // This cannot be optimized. 
Extracting the common subexpression would presence test + // the bound identifier (e.g: has(@r0)), which is not valid. + UNOPTIMIZABLE_TERNARY("has(msg.single_any) ? msg.single_any : 10"); + + private final String source; + + CseNoOpTestCase(String source) { + this.source = source; + } + } + @Test - // Nothing to optimize - @TestParameters("{source: 'size(\"hello\")'}") - // Constants and identifiers - @TestParameters("{source: '2 + 2 + 2 + 2'}") - @TestParameters("{source: 'x + x + x + x'}") - @TestParameters("{source: 'true == true && false == false'}") - // Constants and identifiers within a function - @TestParameters("{source: 'size(\"hello\" + \"hello\" + \"hello\")'}") - @TestParameters("{source: 'string(x + x + x)'}") - // Non-standard functions are considered non-pure for time being - @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") - // Duplicated but nested calls. - @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") - // This cannot be optimized. Extracting the common subexpression would presence test - // the bound identifier (e.g: has(@r0)), which is not valid. - @TestParameters("{source: 'has(msg.single_any) ? 
msg.single_any : 10'}") - public void cse_withCelBind_noop(String source) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + public void cse_withCelBind_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer( @@ -810,28 +1174,12 @@ public void cse_withCelBind_noop(String source) throws Exception { .optimize(ast); assertThat(ast.getExpr()).isEqualTo(optimizedAst.getExpr()); - assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.source); } @Test - // Nothing to optimize - @TestParameters("{source: 'size(\"hello\")'}") - // Constants and identifiers - @TestParameters("{source: '2 + 2 + 2 + 2'}") - @TestParameters("{source: 'x + x + x + x'}") - @TestParameters("{source: 'true == true && false == false'}") - // Constants and identifiers within a function - @TestParameters("{source: 'size(\"hello\" + \"hello\" + \"hello\")'}") - @TestParameters("{source: 'string(x + x + x)'}") - // Non-standard functions are considered non-pure for time being - @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") - // Duplicated but nested calls. - @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") - // This cannot be optimized. Extracting the common subexpression would presence test - // the bound identifier (e.g: has(@r0)), which is not valid. - @TestParameters("{source: 'has(msg.single_any) ? 
msg.single_any : 10'}") - public void cse_withCelBlock_noop(String source) throws Exception { - CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + public void cse_withCelBlock_noop(@TestParameter CseNoOpTestCase testCase) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); CelAbstractSyntaxTree optimizedAst = newCseOptimizer( @@ -842,7 +1190,7 @@ public void cse_withCelBlock_noop(String source) throws Exception { .optimize(ast); assertThat(ast.getExpr()).isEqualTo(optimizedAst.getExpr()); - assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.source); } @Test @@ -1024,9 +1372,10 @@ public void cse_withCelBind_largeNestedMacro() throws Exception { assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( - "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, size(@r0.map(i, @r0.map(i, @r0.map(i," - + " @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, @r0))))))))), @r1 + @r1" - + " + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1))"); + "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, size(@r0.map(@c0:0, @r0.map(@c1:0," + + " @r0.map(@c2:0, @r0.map(@c3:0, @r0.map(@c4:0, @r0.map(@c5:0, @r0.map(@c6:0," + + " @r0.map(@c7:0, @r0))))))))), @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 +" + + " @r1))"); assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); } @@ -1060,10 +1409,10 @@ public void cse_withCelBlock_largeNestedMacro() throws Exception { assertThat(CEL_UNPARSER.unparse(optimizedAst)) .isEqualTo( - "cel.@block([[1, 2, 3], size(@index0.map(i, @index0.map(i, @index0.map(i," - + " @index0.map(i, @index0.map(i, @index0.map(i, @index0.map(i, @index0.map(i," - + " @index0)))))))))], @index1 + @index1 + @index1 + @index1 + @index1 + @index1 +" - + " @index1 + @index1 + @index1)"); + "cel.@block([[1, 2, 3], size(@index0.map(@c0:0, @index0.map(@c1:0, @index0.map(@c2:0," + + " @index0.map(@c3:0, @index0.map(@c4:0, @index0.map(@c5:0, @index0.map(@c6:0," + + " 
@index0.map(@c7:0, @index0)))))))))], @index1 + @index1 + @index1 + @index1 +" + + " @index1 + @index1 + @index1 + @index1 + @index1)"); assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); }