From 790e8cfc657e6f94371c4f77af42a013faeabdb2 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Thu, 18 Jan 2024 19:45:24 -0800 Subject: [PATCH] Implement Optimizer for Common Subexpression Elimination PiperOrigin-RevId: 599696579 --- .gitignore | 2 + .../common/ast/CelExprIdGeneratorFactory.java | 5 + .../common/navigation/CelNavigableExpr.java | 4 + optimizer/optimizers/BUILD.bazel | 6 + .../dev/cel/optimizer/CelAstOptimizer.java | 32 + .../java/dev/cel/optimizer/MutableAst.java | 59 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 21 + .../optimizers/ConstantFoldingOptimizer.java | 5 +- .../optimizers/SubexpressionOptimizer.java | 321 ++++++++++ .../dev/cel/optimizer/MutableAstTest.java | 19 +- .../dev/cel/optimizer/optimizers/BUILD.bazel | 6 + .../ConstantFoldingOptimizerTest.java | 92 +++ .../SubexpressionOptimizerTest.java | 564 ++++++++++++++++++ 13 files changed, 1111 insertions(+), 25 deletions(-) create mode 100644 optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java diff --git a/.gitignore b/.gitignore index e8ebf31c..753e6967 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ bazel-cel-java bazel-out bazel-testlogs +MODULE.bazel* + # IntelliJ IDEA .idea *.iml diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java index 10e259c0..9b7bc60f 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java @@ -64,6 +64,11 @@ public boolean hasId(long id) { return idSet.containsKey(id); } + /** Generates the next available ID. */ + public long nextExprId() { + return ++exprId; + } + /** * Generate the next available ID while memoizing the existing ID. 
* diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java index 76fa5bfb..326f8b78 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java @@ -50,6 +50,10 @@ public enum TraversalOrder { public abstract CelExpr expr(); + public long id() { + return expr().id(); + } + public abstract Optional parent(); /** Represents the count of transitive parents. Depth of an AST's root is 0. */ diff --git a/optimizer/optimizers/BUILD.bazel b/optimizer/optimizers/BUILD.bazel index c6cdf913..e39612db 100644 --- a/optimizer/optimizers/BUILD.bazel +++ b/optimizer/optimizers/BUILD.bazel @@ -7,3 +7,9 @@ java_library( name = "constant_folding", exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:constant_folding"], ) + +java_library( + name = "common_subexpression_elimination", + visibility = ["//visibility:public"], # TODO: Expose when ready + exports = ["//optimizer/src/main/java/dev/cel/optimizer/optimizers:common_subexpression_elimination"], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index 652dca0e..9a9f22c9 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -47,4 +47,36 @@ default CelAbstractSyntaxTree replaceSubtree( CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { return MutableAst.replaceSubtree(ast, newExpr, exprIdToReplace); } + + /** + * Generates a new bind macro using the provided initialization and result expression, then + * replaces the subtree using the new bind expr at the designated expr ID. + * + *

The bind call takes the format of: {@code cel.bind(varName, varInit, resultExpr)} + * + * @param ast Original ast to mutate. + * @param varName New variable name for the bind macro call. + * @param varInit Initialization expression to bind to the local variable. + * @param resultExpr Result expression + * @param exprIdToReplace Expression ID of the subtree that is getting replaced. + */ + default CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( + CelAbstractSyntaxTree ast, + String varName, + CelExpr varInit, + CelExpr resultExpr, + long exprIdToReplace) { + return MutableAst.replaceSubtreeWithNewBindMacro( + ast, varName, varInit, resultExpr, exprIdToReplace); + } + + /** Sets all expr IDs in the expression tree to 0. */ + default CelExpr clearExprIds(CelExpr celExpr) { + return MutableAst.clearExprIds(celExpr); + } + + /** Renumbers all the expr IDs in the given AST in a consecutive manner starting from 1. */ + default CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) { + return MutableAst.renumberIdsConsecutively(ast); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index c8bf99c7..f1f57925 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -31,6 +31,7 @@ import dev.cel.common.ast.CelExpr.CelCreateMap; import dev.cel.common.ast.CelExpr.CelCreateStruct; import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelExprFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator; @@ -42,7 +43,7 @@ /** MutableAst contains logic for mutating a {@link CelExpr}. 
*/ @Internal final class MutableAst { - private static final int MAX_ITERATION_COUNT = 500; + private static final int MAX_ITERATION_COUNT = 1000; private final CelExpr.Builder newExpr; private final ExprIdGenerator celExprIdGenerator; private int iterationCount; @@ -91,7 +92,8 @@ static CelAbstractSyntaxTree replaceSubtree( // Mutate the AST root with the new subtree. All the existing expr IDs are renumbered in the // process, but its original IDs are memoized so that we can normalize the expr IDs // in the macro source map. - StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); + StableIdGenerator stableIdGenerator = + CelExprIdGeneratorFactory.newStableIdGenerator(getMaxId(newAst)); CelExpr.Builder mutatedRoot = replaceSubtreeImpl( stableIdGenerator::renumberId, @@ -147,13 +149,26 @@ static CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( ast, CelAbstractSyntaxTree.newParsedAst(bindMacro.bindExpr(), celSource), exprIdToReplace); } + static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) { + StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); + CelExpr.Builder root = + renumberExprIds(stableIdGenerator::renumberId, ast.getExpr().toBuilder()); + CelSource newSource = + normalizeMacroSource( + ast.getSource(), Integer.MIN_VALUE, root, stableIdGenerator::renumberId); + + return CelAbstractSyntaxTree.newParsedAst(root.build(), newSource); + } + private static BindMacro newBindMacro( String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) { - // Clear incoming expression IDs in the initialization expression to avoid collision with the - // main AST. - varInit = clearExprIds(varInit); + // Renumber incoming expression IDs in the init and result expression to avoid collision with + // the main AST. 
Existing IDs are memoized for a macro source sanitization pass at the end + // (e.g: inserting a bind macro to an existing macro expr) + varInit = renumberExprIds(stableIdGenerator::nextExprId, varInit.toBuilder()).build(); + resultExpr = renumberExprIds(stableIdGenerator::nextExprId, resultExpr.toBuilder()).build(); CelExprFactory exprFactory = - CelExprFactory.newInstance((unused) -> stableIdGenerator.nextExprId(-1)); + CelExprFactory.newInstance((unused) -> stableIdGenerator.nextExprId()); CelExpr bindMacroExpr = exprFactory.fold( "#unused", @@ -164,17 +179,12 @@ private static BindMacro newBindMacro( exprFactory.newIdentifier(varName), resultExpr); - // Update the IDs in the new expression tree first. This ensures that no ID collision - // occurs while attempting to replace the subtree later, potentially leading to an infinite loop - bindMacroExpr = - renumberExprIds(stableIdGenerator::nextExprId, bindMacroExpr.toBuilder()).build(); - CelExpr bindMacroCallExpr = exprFactory .newReceiverCall( "bind", - CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), "cel"), - CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), varName), + CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(), "cel"), + CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(), varName), bindMacroExpr.comprehension().accuInit(), bindMacroExpr.comprehension().result()) .toBuilder() @@ -270,6 +280,7 @@ private static CelSource normalizeMacroSource( if (!allExprs.containsKey(callChild.id())) { continue; } + CelExpr mutatedExpr = allExprs.get(callChild.id()); if (!callChild.equals(mutatedExpr)) { newCall = @@ -279,6 +290,25 @@ private static CelSource normalizeMacroSource( sourceBuilder.addMacroCalls(callId, newCall.build()); } + // Replace comprehension nodes with a NOT_SET reference to reduce AST size. 
+ for (Entry macroCall : sourceBuilder.getMacroCalls().entrySet()) { + CelExpr macroCallExpr = macroCall.getValue(); + CelNavigableExpr.fromExpr(macroCallExpr) + .allNodes() + .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) + .map(CelNavigableExpr::expr) + .forEach( + node -> { + CelExpr.Builder mutatedNode = + replaceSubtreeImpl( + (id) -> id, + macroCallExpr.toBuilder(), + CelExpr.ofNotSet(node.id()).toBuilder(), + node.id()); + macroCall.setValue(mutatedNode.build()); + }); + } + return sourceBuilder.build(); } @@ -309,7 +339,7 @@ private static long getMaxId(CelAbstractSyntaxTree ast) { private static long getMaxId(CelExpr newExpr) { return CelNavigableExpr.fromExpr(newExpr) .allNodes() - .mapToLong(node -> node.expr().id()) + .mapToLong(CelNavigableExpr::id) .max() .orElseThrow(NoSuchElementException::new); } @@ -419,7 +449,6 @@ private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder com */ @AutoValue abstract static class BindMacro { - /** Comprehension expr for the generated cel.bind macro. 
*/ abstract CelExpr bindExpr(); diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 730f0c0d..f2519c83 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -27,3 +27,24 @@ java_library( "@maven//:com_google_guava_guava", ], ) + +java_library( + name = "common_subexpression_elimination", + srcs = [ + "SubexpressionOptimizer.java", + ], + tags = [ + ], + deps = [ + "//:auto_value", + "//bundle:cel", + "//checker:checker_legacy_environment", + "//common", + "//common/ast", + "//common/navigation", + "//optimizer:ast_optimizer", + "//parser:operator", + "@maven//:com_google_guava_guava", + "@maven//:org_jspecify_jspecify", + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index 460654bb..272cd54f 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -98,9 +98,8 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) // If the output is a list, map, or struct which contains optional entries, then prune it // to make sure that the optionals, if resolved, do not surface in the output literal. 
- navigableAst = CelNavigableAst.fromAst(pruneOptionalElements(navigableAst)); - - return navigableAst.getAst(); + CelAbstractSyntaxTree newAst = pruneOptionalElements(navigableAst); + return renumberIdsConsecutively(newAst); } private static boolean canFold(CelNavigableExpr navigableExpr) { diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java new file mode 100644 index 00000000..88ae694e --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -0,0 +1,321 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package dev.cel.optimizer.optimizers; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Arrays.stream; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import dev.cel.bundle.Cel; +import dev.cel.checker.Standard; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelSource; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelIdent; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; +import dev.cel.common.navigation.CelNavigableAst; +import dev.cel.common.navigation.CelNavigableExpr; +import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder; +import dev.cel.optimizer.CelAstOptimizer; +import dev.cel.parser.Operator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.stream.Stream; + +/** + * Performs Common Subexpression Elimination. + * + *

+ * Subexpressions are extracted into {@code cel.bind} calls. For example, the expression below:
+ *
+ * {@code
+ *    message.child.text_map[x].startsWith("hello") && message.child.text_map[x].endsWith("world")
+ * }
+ *
+ * will be optimized into the following form:
+ *
+ * {@code
+ *    cel.bind(@r0, message.child.text_map[x],
+ *        @r0.startsWith("hello") && @r0.endsWith("world"))
+ * }
+ * 
+ */ +public class SubexpressionOptimizer implements CelAstOptimizer { + + private static final SubexpressionOptimizer INSTANCE = + new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); + private static final String BIND_IDENTIFIER_PREFIX = "@r"; + private static final ImmutableSet CSE_ALLOWED_FUNCTIONS = + Streams.concat( + stream(Operator.values()).map(Operator::getFunction), + stream(Standard.Function.values()).map(Standard.Function::getFunction)) + .collect(toImmutableSet()); + private final SubexpressionOptimizerOptions cseOptions; + + /** + * Returns a default instance of common subexpression elimination optimizer with preconfigured + * defaults. + */ + public static SubexpressionOptimizer getInstance() { + return INSTANCE; + } + + /** + * Returns a new instance of common subexpression elimination optimizer configured with the + * provided {@link SubexpressionOptimizerOptions}. + */ + public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions cseOptions) { + return new SubexpressionOptimizer(cseOptions); + } + + @Override + public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { + CelAbstractSyntaxTree astToModify = navigableAst.getAst(); + CelSource sourceToModify = astToModify.getSource(); + int bindIdentifierIndex = 0; + int iterCount; + for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { + CelNavigableExpr cseCandidate = findCseCandidate(astToModify).orElse(null); + if (cseCandidate == null) { + break; + } + + String bindIdentifier = BIND_IDENTIFIER_PREFIX + bindIdentifierIndex; + bindIdentifierIndex++; + + // Using the CSE candidate, fetch all semantically equivalent subexpressions ahead of time. 
+ ImmutableList allCseCandidates = + getAllCseCandidatesStream(astToModify, cseCandidate.expr()).collect(toImmutableList()); + + // Replace all CSE candidates with new bind identifier + for (CelExpr semanticallyEqualNode : allCseCandidates) { + iterCount++; + // Refetch the candidate expr as mutating the AST could have renumbered its IDs. + CelExpr exprToReplace = + getAllCseCandidatesStream(astToModify, semanticallyEqualNode) + .findAny() + .orElseThrow( + () -> + new NoSuchElementException( + "No value present for expr ID: " + semanticallyEqualNode.id())); + + astToModify = + replaceSubtree( + astToModify, + CelExpr.newBuilder() + .setIdent(CelIdent.newBuilder().setName(bindIdentifier).build()) + .build(), + exprToReplace.id()); + } + + // Find LCA to insert the new cel.bind macro into. + CelNavigableExpr lca = getLca(astToModify, bindIdentifier); + + sourceToModify = + sourceToModify.toBuilder() + .addAllMacroCalls(astToModify.getSource().getMacroCalls()) + .build(); + astToModify = CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), sourceToModify); + + // Insert the new bind call + astToModify = + replaceSubtreeWithNewBindMacro( + astToModify, bindIdentifier, cseCandidate.expr(), lca.expr(), lca.id()); + + // Retain the existing macro calls in case if the bind identifiers are replacing a subtree + // that contains a comprehension. + sourceToModify = astToModify.getSource(); + } + + if (iterCount >= cseOptions.maxIterationLimit()) { + throw new IllegalStateException("Max iteration count reached."); + } + + if (iterCount == 0) { + // No modification has been made. 
+ return astToModify; + } + + if (!cseOptions.populateMacroCalls()) { + astToModify = + CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); + } + + return renumberIdsConsecutively(astToModify); + } + + private Stream getAllCseCandidatesStream( + CelAbstractSyntaxTree ast, CelExpr cseCandidate) { + return CelNavigableAst.fromAst(ast) + .getRoot() + .allNodes() + .filter(SubexpressionOptimizer::canEliminate) + .map(CelNavigableExpr::expr) + .filter(expr -> areSemanticallyEqual(cseCandidate, expr)); + } + + private static CelNavigableExpr getLca(CelAbstractSyntaxTree ast, String boundIdentifier) { + CelNavigableExpr root = CelNavigableAst.fromAst(ast).getRoot(); + ImmutableList allNodesWithIdentifier = + root.allNodes() + .filter(node -> node.expr().identOrDefault().name().equals(boundIdentifier)) + .collect(toImmutableList()); + + if (allNodesWithIdentifier.size() < 2) { + throw new IllegalStateException("Expected at least 2 bound identifiers to be present."); + } + + CelNavigableExpr lca = root; + long lcaAncestorCount = 0; + HashMap ancestors = new HashMap<>(); + for (CelNavigableExpr navigableExpr : allNodesWithIdentifier) { + Optional maybeParent = Optional.of(navigableExpr); + while (maybeParent.isPresent()) { + CelNavigableExpr parent = maybeParent.get(); + if (!ancestors.containsKey(parent.id())) { + ancestors.put(parent.id(), 1L); + continue; + } + + long ancestorCount = ancestors.get(parent.id()); + if (lcaAncestorCount < ancestorCount + || (lcaAncestorCount == ancestorCount && lca.depth() < parent.depth())) { + lca = parent; + lcaAncestorCount = ancestorCount; + } + + ancestors.put(parent.id(), ancestorCount + 1); + maybeParent = parent.parent(); + } + } + + return lca; + } + + private Optional findCseCandidate(CelAbstractSyntaxTree ast) { + HashSet encounteredNodes = new HashSet<>(); + ImmutableList allNodes = + CelNavigableAst.fromAst(ast) + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + 
.filter(SubexpressionOptimizer::canEliminate) + .collect(toImmutableList()); + + for (CelNavigableExpr node : allNodes) { + // Strip out all IDs to test equivalence + CelExpr celExpr = clearExprIds(node.expr()); + if (encounteredNodes.contains(celExpr)) { + return Optional.of(node); + } + + encounteredNodes.add(celExpr); + } + + return Optional.empty(); + } + + private static boolean canEliminate(CelNavigableExpr navigableExpr) { + return !navigableExpr.getKind().equals(Kind.CONSTANT) + && !navigableExpr.getKind().equals(Kind.IDENT) + && !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX) + && isAllowedFunction(navigableExpr) + && isWithinInlineableComprehension(navigableExpr); + } + + private static boolean isWithinInlineableComprehension(CelNavigableExpr expr) { + Optional maybeParent = expr.parent(); + while (maybeParent.isPresent()) { + CelNavigableExpr parent = maybeParent.get(); + if (parent.getKind().equals(Kind.COMPREHENSION)) { + return Streams.concat( + // If the expression is within a comprehension, it is eligible for CSE iff is in + // result or iterRange. While result is not human authored, it needs to be included + // to extract subexpressions that are already in cel.bind macro. + CelNavigableExpr.fromExpr(parent.expr().comprehension().result()).descendants(), + CelNavigableExpr.fromExpr(parent.expr().comprehension().iterRange()).allNodes()) + .filter( + node -> + // Exclude empty lists (cel.bind sets this for iterRange). 
+ !node.getKind().equals(Kind.CREATE_LIST) + || !node.expr().createList().elements().isEmpty()) + .map(CelNavigableExpr::expr) + .anyMatch(node -> node.equals(expr.expr())); + } + maybeParent = parent.parent(); + } + + return true; + } + + private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) { + return clearExprIds(expr1).equals(clearExprIds(expr2)); + } + + private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) { + if (navigableExpr.getKind().equals(Kind.CALL)) { + return CSE_ALLOWED_FUNCTIONS.contains(navigableExpr.expr().call().function()); + } + + return true; + } + + /** Options to configure how Common Subexpression Elimination behave. */ + @AutoValue + public abstract static class SubexpressionOptimizerOptions { + public abstract int maxIterationLimit(); + + public abstract boolean populateMacroCalls(); + + /** Builder for configuring the {@link SubexpressionOptimizerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + + /** + * Limit the number of iteration while performing CSE. An exception is thrown if the iteration + * count exceeds the set value. + */ + public abstract Builder maxIterationLimit(int value); + + /** + * Populate the macro_calls map in source_info with macro calls on the resulting optimized + * AST. + */ + public abstract Builder populateMacroCalls(boolean value); + + public abstract SubexpressionOptimizerOptions build(); + + Builder() {} + } + + /** Returns a new options builder with recommended defaults pre-configured. 
*/ + public static Builder newBuilder() { + return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder() + .maxIterationLimit(500) + .populateMacroCalls(false); + } + + SubexpressionOptimizerOptions() {} + } + + private SubexpressionOptimizer(SubexpressionOptimizerOptions cseOptions) { + this.cseOptions = cseOptions; + } +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index b7f1d29d..587c3c71 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -71,7 +71,7 @@ public void constExpr() throws Exception { ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); assertThat(mutatedAst.getExpr()) - .isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(true))); + .isEqualTo(CelExpr.ofConstantExpr(3, CelConstant.ofValue(true))); } @Test @@ -126,7 +126,7 @@ public void replaceSubtree_rootReplacedWithMacro_macroCallPopulated( CelAbstractSyntaxTree ast2 = CEL.compile(source).getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().expr().id()); + MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(expectedMacroCallSize); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo(source); @@ -185,7 +185,7 @@ public void replaceSubtreeWithNewBindMacro_replaceRoot() throws Exception { variableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), resultExpr, - CelNavigableAst.fromAst(ast).getRoot().expr().id()); + CelNavigableAst.fromAst(ast).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(1); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("cel.bind(@r0, 3, @r0 + @r0)"); @@ -333,7 +333,7 @@ public void 
replaceSubtreeWithNewBindMacro_replaceRootWithNestedBindMacro() thro nestedVariableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), resultExpr, - 1); + CelNavigableAst.fromAst(mutatedAst).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(2); assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(8); @@ -349,7 +349,7 @@ public void replaceSubtree_macroReplacedWithConstExpr_macroCallCleared() throws CelAbstractSyntaxTree ast2 = CEL.compile("1").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().expr().id()); + MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty(); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("1"); @@ -368,7 +368,7 @@ public void globalCallExpr_replaceRoot() throws Exception { MutableAst.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); - assertThat(replacedAst.getExpr()).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(10))); + assertThat(replacedAst.getExpr()).isEqualTo(CelExpr.ofConstantExpr(7, CelConstant.ofValue(10))); } @Test @@ -673,7 +673,12 @@ private void assertConsistentMacroCalls(CelAbstractSyntaxTree ast) { node -> { CelExpr e = allExprs.get(node.id()); if (e != null) { - assertThat(node).isEqualTo(e); + assertThat(node.id()).isEqualTo(e.id()); + if (e.exprKind().getKind().equals(Kind.COMPREHENSION)) { + assertThat(node.exprKind().getKind()).isEqualTo(Kind.NOT_SET); + } else { + assertThat(node.exprKind().getKind()).isEqualTo(e.exprKind().getKind()); + } } }); } diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel index 3e111984..a4f8143f 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ 
b/optimizer/src/test/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -10,16 +10,22 @@ java_library( "//:java_truth", "//bundle:cel", "//common", + "//common:compiler_common", "//common:options", + "//common/ast", "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types", + "//extensions", "//extensions:optional_library", "//optimizer", "//optimizer:optimization_exception", "//optimizer:optimizer_builder", + "//optimizer/optimizers:common_subexpression_elimination", "//optimizer/optimizers:constant_folding", "//parser:macro", + "//parser:operator", "//parser:unparser", + "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:junit_junit", ], diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index f47481be..eaa8c6ce 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -288,6 +288,98 @@ public void constantFold_noOp(String source) throws Exception { assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); } + @Test + public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNotSet() + throws Exception { + Cel cel = + CelFactory.standardCelBuilder() + .addVar("x", SimpleType.DYN) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions(CelOptions.current().populateMacroCalls(true).build()) + .build(); + CelOptimizer celOptimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(cel) + .addAstOptimizers(ConstantFoldingOptimizer.INSTANCE) + .build(); + CelAbstractSyntaxTree ast = + cel.compile("[1, 1 + 1, 1 + 1+ 1].map(i, i).filter(j, j % 2 == x)").getAst(); + + CelAbstractSyntaxTree optimizedAst = celOptimizer.optimize(ast); + + 
assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo("[1, 2, 3].map(i, i).filter(j, j % 2 == x)"); + assertThat(optimizedAst.getSource().getMacroCalls()).hasSize(2); + assertThat(optimizedAst.getSource().getMacroCalls().get(1L).toString()) + .isEqualTo( + "CALL [0] {\n" + + " function: filter\n" + + " target: {\n" + + " NOT_SET [2] {}\n" + + " }\n" + + " args: {\n" + + " IDENT [25] {\n" + + " name: j\n" + + " }\n" + + " CALL [17] {\n" + + " function: _==_\n" + + " args: {\n" + + " CALL [18] {\n" + + " function: _%_\n" + + " args: {\n" + + " IDENT [19] {\n" + + " name: j\n" + + " }\n" + + " CONSTANT [20] { value: 2 }\n" + + " }\n" + + " }\n" + + " IDENT [21] {\n" + + " name: x\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"); + assertThat(optimizedAst.getSource().getMacroCalls().get(2L).toString()) + .isEqualTo( + "CALL [0] {\n" + + " function: map\n" + + " target: {\n" + + " CREATE_LIST [3] {\n" + + " elements: {\n" + + " CONSTANT [4] { value: 1 }\n" + + " CONSTANT [5] { value: 2 }\n" + + " CONSTANT [6] { value: 3 }\n" + + " }\n" + + " }\n" + + " }\n" + + " args: {\n" + + " IDENT [28] {\n" + + " name: i\n" + + " }\n" + + " IDENT [12] {\n" + + " name: i\n" + + " }\n" + + " }\n" + + "}"); + } + + @Test + public void constantFold_astProducesConsistentlyNumberedIds() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[1] + [2] + [3]").getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(optimizedAst.getExpr().toString()) + .isEqualTo( + "CREATE_LIST [1] {\n" + + " elements: {\n" + + " CONSTANT [2] { value: 1 }\n" + + " CONSTANT [3] { value: 2 }\n" + + " CONSTANT [4] { value: 3 }\n" + + " }\n" + + "}"); + } + @Test public void maxIterationCountReached_throws() throws Exception { StringBuilder sb = new StringBuilder(); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java new 
file mode 100644 index 00000000..15741750 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -0,0 +1,564 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer.optimizers; + +import static com.google.common.truth.Truth.assertThat; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelBuilder; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOptions; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.extensions.CelExtensions; +import dev.cel.extensions.CelOptionalLibrary; +import dev.cel.optimizer.CelOptimizationException; +import dev.cel.optimizer.CelOptimizer; +import dev.cel.optimizer.CelOptimizerFactory; +import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions; 
+import dev.cel.parser.CelStandardMacro; +import dev.cel.parser.CelUnparser; +import dev.cel.parser.CelUnparserFactory; +import dev.cel.parser.Operator; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.NestedTestAllTypes; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class SubexpressionOptimizerTest { + + private static final Cel CEL = newCelBuilder().build(); + + private static final CelOptimizer CEL_OPTIMIZER = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build())) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + private static final TestAllTypes TEST_ALL_TYPES_INPUT = + TestAllTypes.newBuilder() + .setSingleInt64(3L) + .setSingleInt32(5) + .setOneofType( + NestedTestAllTypes.newBuilder() + .setPayload( + TestAllTypes.newBuilder() + .setSingleInt32(8) + .setSingleInt64(10L) + .putMapInt32Int64(0, 1) + .putMapInt32Int64(1, 5) + .putMapInt32Int64(2, 2))) + .build(); + + private static CelBuilder newCelBuilder() { + return CelFactory.standardCelBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .setOptions( + CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build()) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "custom_func", + newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT))) + .addVar("x", SimpleType.DYN) + .addVar("msg", 
StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())); + } + + @Test + public void cse_producesOptimizedAst() throws Exception { + CelAbstractSyntaxTree ast = + CEL.compile("size([0]) + size([0]) + size([1,2]) + size([1,2])").getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(6); + assertThat(optimizedAst.getExpr().toString()) + .isEqualTo( + "COMPREHENSION [1] {\n" + + " iter_var: #unused\n" + + " iter_range: {\n" + + " CREATE_LIST [2] {\n" + + " elements: {\n" + + " }\n" + + " }\n" + + " }\n" + + " accu_var: @r1\n" + + " accu_init: {\n" + + " CALL [3] {\n" + + " function: size\n" + + " args: {\n" + + " CREATE_LIST [4] {\n" + + " elements: {\n" + + " CONSTANT [5] { value: 1 }\n" + + " CONSTANT [6] { value: 2 }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " loop_condition: {\n" + + " CONSTANT [7] { value: false }\n" + + " }\n" + + " loop_step: {\n" + + " IDENT [8] {\n" + + " name: @r1\n" + + " }\n" + + " }\n" + + " result: {\n" + + " CALL [9] {\n" + + " function: _+_\n" + + " args: {\n" + + " CALL [10] {\n" + + " function: _+_\n" + + " args: {\n" + + " COMPREHENSION [11] {\n" + + " iter_var: #unused\n" + + " iter_range: {\n" + + " CREATE_LIST [12] {\n" + + " elements: {\n" + + " }\n" + + " }\n" + + " }\n" + + " accu_var: @r0\n" + + " accu_init: {\n" + + " CALL [13] {\n" + + " function: size\n" + + " args: {\n" + + " CREATE_LIST [14] {\n" + + " elements: {\n" + + " CONSTANT [15] { value: 0 }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " loop_condition: {\n" + + " CONSTANT [16] { value: false }\n" + + " }\n" + + " loop_step: {\n" + + " IDENT [17] {\n" + + " name: @r0\n" + + " }\n" + + " }\n" + + " result: {\n" + + " CALL [18] {\n" + + " function: _+_\n" + + " args: {\n" + + " IDENT [19] {\n" + + " name: @r0\n" + + " }\n" + + " IDENT [20] {\n" + + " name: @r0\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " 
IDENT [21] {\n" + + " name: @r1\n" + + " }\n" + + " }\n" + + " }\n" + + " IDENT [22] {\n" + + " name: @r1\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"); + } + + private enum CseTestCase { + SIZE_1("size([1,2]) + size([1,2]) + 1 == 5", "cel.bind(@r0, size([1, 2]), @r0 + @r0) + 1 == 5"), + SIZE_2( + "2 + size([1,2]) + size([1,2]) + 1 == 7", + "cel.bind(@r0, size([1, 2]), 2 + @r0 + @r0) + 1 == 7"), + SIZE_3( + "size([0]) + size([0]) + size([1,2]) + size([1,2]) == 6", + "cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), @r0 + @r0) + @r1 + @r1) == 6"), + SIZE_4( + "5 + size([0]) + size([0]) + size([1,2]) + size([1,2]) + " + + "size([1,2,3]) + size([1,2,3]) == 17", + "cel.bind(@r2, size([1, 2, 3]), cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0]), 5 +" + + " @r0 + @r0) + @r1 + @r1) + @r2 + @r2) == 17"), + /** + * Unparsed form is: + * + *
+     * <pre>{@code
+     * cel.bind(@r0, timestamp(int(timestamp(1000000000))).getFullYear(),
+     *    cel.bind(@r3, timestamp(int(timestamp(75))),
+     *      cel.bind(@r2, timestamp(int(timestamp(200))).getFullYear(),
+     *        cel.bind(@r1, timestamp(int(timestamp(50))),
+     *          @r0 + @r3.getFullYear() + @r1.getFullYear() + @r0 + @r1.getSeconds()
+     *        ) + @r2 + @r2
+     *      ) + @r3.getMinutes()
+     *    ) + @r0
+     * ) == 13934
+     * }
+     * </pre>
+ */ + TIMESTAMP( + "timestamp(int(timestamp(1000000000))).getFullYear() +" + + " timestamp(int(timestamp(75))).getFullYear() + " + + " timestamp(int(timestamp(50))).getFullYear() + " + + " timestamp(int(timestamp(1000000000))).getFullYear() + " + + " timestamp(int(timestamp(50))).getSeconds() + " + + " timestamp(int(timestamp(200))).getFullYear() + " + + " timestamp(int(timestamp(200))).getFullYear() + " + + " timestamp(int(timestamp(75))).getMinutes() + " + + " timestamp(int(timestamp(1000000000))).getFullYear() == 13934", + "cel.bind(@r0, timestamp(int(timestamp(1000000000))).getFullYear(), " + + "cel.bind(@r3, timestamp(int(timestamp(75))), " + + "cel.bind(@r2, timestamp(int(timestamp(200))).getFullYear(), " + + "cel.bind(@r1, timestamp(int(timestamp(50))), " + + "@r0 + @r3.getFullYear() + @r1.getFullYear() + " + + "@r0 + @r1.getSeconds()) + @r2 + @r2) + @r3.getMinutes()) + @r0) == 13934"), + MAP_INDEX( + "{\"a\": 2}[\"a\"] + {\"a\": 2}[\"a\"] * {\"a\": 2}[\"a\"] == 6", + "cel.bind(@r0, {\"a\": 2}[\"a\"], @r0 + @r0 * @r0) == 6"), + /** + * Input map is: + * + *
 <pre>{@code
+     * {
+     *    "a": { "b": 1 },
+     *    "c": { "b": 1 },
+     *    "d": {
+     *       "e": { "b": 1 }
+     *    },
+     *    "e": {
+     *       "e": { "b": 1 }
+     *    }
+     * }
+     * }</pre>
+ */ + NESTED_MAP_CONSTRUCTION( + "size({'a': {'b': 1}, 'c': {'b': 1}, 'd': {'e': {'b': 1}}, 'e': {'e': {'b': 1}}}) == 4", + "size(cel.bind(@r0, {\"b\": 1}, cel.bind(@r1, {\"e\": @r0}, {\"a\": @r0, \"c\": @r0, \"d\":" + + " @r1, \"e\": @r1}))) == 4"), + NESTED_LIST_CONSTRUCTION( + "size([1, [1,2,3,4], 2, [1,2,3,4], 5, [1,2,3,4], 7, [[1,2], [1,2,3,4]], [1,2]]) == 9", + "size(cel.bind(@r0, [1, 2, 3, 4], " + + "cel.bind(@r1, [1, 2], [1, @r0, 2, @r0, 5, @r0, 7, [@r1, @r0], @r1]))) == 9"), + SELECT( + "msg.single_int64 + msg.single_int64 == 6", + "cel.bind(@r0, msg.single_int64, @r0 + @r0) == 6"), + SELECT_NESTED( + "msg.oneof_type.payload.single_int64 + msg.oneof_type.payload.single_int32 + " + + "msg.oneof_type.payload.single_int64 + " + + "msg.single_int64 + msg.oneof_type.payload.oneof_type.payload.single_int64 == 31", + "cel.bind(@r0, msg.oneof_type.payload, " + + "cel.bind(@r1, @r0.single_int64, @r1 + @r0.single_int32 + @r1) + " + + "msg.single_int64 + @r0.oneof_type.payload.single_int64) == 31"), + SELECT_NESTED_MESSAGE_MAP_INDEX_1( + "msg.oneof_type.payload.map_int32_int64[1] + " + + "msg.oneof_type.payload.map_int32_int64[1] + " + + "msg.oneof_type.payload.map_int32_int64[1] == 15", + "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64[1], @r0 + @r0 + @r0) == 15"), + SELECT_NESTED_MESSAGE_MAP_INDEX_2( + "msg.oneof_type.payload.map_int32_int64[0] + " + + "msg.oneof_type.payload.map_int32_int64[1] + " + + "msg.oneof_type.payload.map_int32_int64[2] == 8", + "cel.bind(@r0, msg.oneof_type.payload.map_int32_int64, @r0[0] + @r0[1] + @r0[2]) == 8"), + TERNARY( + "(msg.single_int64 > 0 ? msg.single_int64 : 0) == 3", + "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? @r0 : 0) == 3"), + TERNARY_BIND_RHS_ONLY( + "false ? false : (msg.single_int64) + ((msg.single_int64 + 1) * 2) == 11", + "false ? false : (cel.bind(@r0, msg.single_int64, @r0 + (@r0 + 1) * 2) == 11)"), + NESTED_TERNARY( + "(msg.single_int64 > 0 ? (msg.single_int32 > 0 ? 
" + + "msg.single_int64 + msg.single_int32 : 0) : 0) == 8", + "cel.bind(@r0, msg.single_int64, (@r0 > 0) ? " + + "cel.bind(@r1, msg.single_int32, (@r1 > 0) ? (@r0 + @r1) : 0) : 0) == 8"), + MULTIPLE_MACROS( + "size([[1].exists(x, x > 0)]) + size([[1].exists(x, x > 0)]) + " + + "size([[2].exists(x, x > 1)]) + size([[2].exists(x, x > 1)]) == 4", + "cel.bind(@r1, size([[2].exists(x, x > 1)]), " + + "cel.bind(@r0, size([[1].exists(x, x > 0)]), @r0 + @r0) + @r1 + @r1) == 4"), + NESTED_MACROS( + "[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]", + "cel.bind(@r0, [1, 2, 3], @r0.map(i, @r0.map(i, i + 1))) == cel.bind(@r1, [2, 3, 4], [@r1," + + " @r1, @r1])"), + MACRO_SHADOWED_VARIABLE( + // Macro variable `x` in .exists is shadowed. + // This is left intact due to the fact that loop condition is not optimized at the moment. + "[x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3", + "cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(x, x - 1 > 3) ||" + + " @r1))"); + + private final String source; + private final String unparsed; + + CseTestCase(String source, String unparsed) { + this.source = source; + this.unparsed = unparsed; + } + } + + @Test + public void cse_withMacroMapPopulated_success(@TestParameter CseTestCase testCase) + throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(testCase.source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat( + CEL.createProgram(optimizedAst) + .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L))) + .isEqualTo(true); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(testCase.unparsed); + } + + @Test + public void cse_withoutMacroMap_success(@TestParameter CseTestCase testCase) throws Exception { + Cel celWithoutMacroMap = + newCelBuilder() + .setOptions( + CelOptions.current().populateMacroCalls(false).enableTimestampEpoch(true).build()) + .build(); + CelAbstractSyntaxTree ast = 
celWithoutMacroMap.compile(testCase.source).getAst(); + + CelAbstractSyntaxTree optimizedAst = + CelOptimizerFactory.standardCelOptimizerBuilder(celWithoutMacroMap) + .addAstOptimizers(SubexpressionOptimizer.getInstance()) + .build() + .optimize(ast); + + assertThat(optimizedAst.getSource().getMacroCalls()).isEmpty(); + assertThat( + celWithoutMacroMap + .createProgram(optimizedAst) + .eval(ImmutableMap.of("msg", TEST_ALL_TYPES_INPUT, "x", 5L))) + .isEqualTo(true); + } + + @Test + // Nothing to optimize + @TestParameters("{source: 'size(\"hello\")'}") + // Constants and identifiers + @TestParameters("{source: '2 + 2 + 2 + 2'}") + @TestParameters("{source: 'x + x + x + x'}") + @TestParameters("{source: 'true == true && false == false'}") + // Constants and identifiers within a function + @TestParameters("{source: 'size(\"hello\" + \"hello\" + \"hello\")'}") + @TestParameters("{source: 'string(x + x + x)'}") + // Non-standard functions are considered non-pure for time being + @TestParameters("{source: 'custom_func(1) + custom_func(1)'}") + // Duplicated but nested calls. + @TestParameters("{source: 'int(timestamp(int(timestamp(1000000000))))'}") + // Loop condition is not optimized at the moment. This requires mangling. + @TestParameters("{source: '[\"foo\", \"bar\"].map(x, [x + x, x + x]).map(x, [x + x, x + x])'}") + // Ternary with presence test is not supported yet. + @TestParameters("{source: 'has(msg.single_any) ? 
msg.single_any : 10'}") + public void cse_noop(String source) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(ast.getExpr()).isEqualTo(optimizedAst.getExpr()); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo(source); + } + + @Test + public void cse_largeCalcExpr() throws Exception { + StringBuilder sb = new StringBuilder(); + int limit = 40; + for (int i = 0; i < limit; i++) { + sb.append("size([1]) + "); + sb.append("size([1,2]) + "); + sb.append("size([1,2,3]) +"); + sb.append("size([1,2,3,4])"); + if (i < limit - 1) { + sb.append("+"); + } + } + CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.bind(@r3, size([1, 2, 3, 4]), cel.bind(@r2, size([1, 2, 3]), cel.bind(@r1," + + " size([1, 2]), cel.bind(@r0, size([1]), @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2" + + " + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 +" + + " @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 +" + + " @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 +" + + " @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 +" + + " @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 +" + + " @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 +" + + " @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 +" + + " @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 +" + + " @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 +" + + " @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 +" + + " @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 + @r2 + @r3 + @r0 + @r1 +" + + " @r2 + 
@r3 + @r0 + @r1 + @r2 + @r3 + @r0) + @r1) + @r2) + @r3)"); + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(400L); + } + + @Test + public void cse_largeNestedBinds() throws Exception { + StringBuilder sb = new StringBuilder(); + int limit = 50; + for (int i = 0; i < limit; i++) { + sb.append(String.format("size([%d, %d]) + ", i, i + 1)); + sb.append(String.format("size([%d, %d]) ", i, i + 1)); + if (i < limit - 1) { + sb.append("+"); + } + } + CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.bind(@r49, size([49, 50]), cel.bind(@r48, size([48, 49]), cel.bind(@r47, size([47," + + " 48]), cel.bind(@r46, size([46, 47]), cel.bind(@r45, size([45, 46])," + + " cel.bind(@r44, size([44, 45]), cel.bind(@r43, size([43, 44]), cel.bind(@r42," + + " size([42, 43]), cel.bind(@r41, size([41, 42]), cel.bind(@r40, size([40, 41])," + + " cel.bind(@r39, size([39, 40]), cel.bind(@r38, size([38, 39]), cel.bind(@r37," + + " size([37, 38]), cel.bind(@r36, size([36, 37]), cel.bind(@r35, size([35, 36])," + + " cel.bind(@r34, size([34, 35]), cel.bind(@r33, size([33, 34]), cel.bind(@r32," + + " size([32, 33]), cel.bind(@r31, size([31, 32]), cel.bind(@r30, size([30, 31])," + + " cel.bind(@r29, size([29, 30]), cel.bind(@r28, size([28, 29]), cel.bind(@r27," + + " size([27, 28]), cel.bind(@r26, size([26, 27]), cel.bind(@r25, size([25, 26])," + + " cel.bind(@r24, size([24, 25]), cel.bind(@r23, size([23, 24]), cel.bind(@r22," + + " size([22, 23]), cel.bind(@r21, size([21, 22]), cel.bind(@r20, size([20, 21])," + + " cel.bind(@r19, size([19, 20]), cel.bind(@r18, size([18, 19]), cel.bind(@r17," + + " size([17, 18]), cel.bind(@r16, size([16, 17]), cel.bind(@r15, size([15, 16])," + + " cel.bind(@r14, size([14, 15]), cel.bind(@r13, size([13, 14]), cel.bind(@r12," + + " size([12, 13]), cel.bind(@r11, size([11, 12]), cel.bind(@r10, 
size([10, 11])," + + " cel.bind(@r9, size([9, 10]), cel.bind(@r8, size([8, 9]), cel.bind(@r7, size([7," + + " 8]), cel.bind(@r6, size([6, 7]), cel.bind(@r5, size([5, 6]), cel.bind(@r4," + + " size([4, 5]), cel.bind(@r3, size([3, 4]), cel.bind(@r2, size([2, 3])," + + " cel.bind(@r1, size([1, 2]), cel.bind(@r0, size([0, 1]), @r0 + @r0) + @r1 + @r1)" + + " + @r2 + @r2) + @r3 + @r3) + @r4 + @r4) + @r5 + @r5) + @r6 + @r6) + @r7 + @r7) +" + + " @r8 + @r8) + @r9 + @r9) + @r10 + @r10) + @r11 + @r11) + @r12 + @r12) + @r13 +" + + " @r13) + @r14 + @r14) + @r15 + @r15) + @r16 + @r16) + @r17 + @r17) + @r18 +" + + " @r18) + @r19 + @r19) + @r20 + @r20) + @r21 + @r21) + @r22 + @r22) + @r23 +" + + " @r23) + @r24 + @r24) + @r25 + @r25) + @r26 + @r26) + @r27 + @r27) + @r28 +" + + " @r28) + @r29 + @r29) + @r30 + @r30) + @r31 + @r31) + @r32 + @r32) + @r33 +" + + " @r33) + @r34 + @r34) + @r35 + @r35) + @r36 + @r36) + @r37 + @r37) + @r38 +" + + " @r38) + @r39 + @r39) + @r40 + @r40) + @r41 + @r41) + @r42 + @r42) + @r43 +" + + " @r43) + @r44 + @r44) + @r45 + @r45) + @r46 + @r46) + @r47 + @r47) + @r48 +" + + " @r48) + @r49 + @r49)"); + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(200L); + } + + @Test + public void cse_largeNestedMacro() throws Exception { + StringBuilder sb = new StringBuilder(); + sb.append("size([1,2,3]"); + int limit = 8; + for (int i = 0; i < limit; i++) { + sb.append(".map(i, [1, 2, 3]"); + } + for (int i = 0; i < limit; i++) { + sb.append(")"); + } + sb.append(")"); + String nestedMapCallExpr = sb.toString(); // size([1,2,3].map(i, [1,2,3].map(i, [1,2,3].map(... 
+ // Add this large macro call 8 times + for (int i = 0; i < limit; i++) { + sb.append("+"); + sb.append(nestedMapCallExpr); + } + CelAbstractSyntaxTree ast = CEL.compile(sb.toString()).getAst(); + + CelAbstractSyntaxTree optimizedAst = CEL_OPTIMIZER.optimize(ast); + + assertThat(CEL_UNPARSER.unparse(optimizedAst)) + .isEqualTo( + "cel.bind(@r0, [1, 2, 3], cel.bind(@r1, size(@r0.map(i, @r0.map(i, @r0.map(i," + + " @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, @r0.map(i, [1, 2, 3]))))))))), @r1" + + " + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1 + @r1))"); + assertThat(CEL.createProgram(optimizedAst).eval()).isEqualTo(27); + } + + @Test + public void cse_applyConstFoldingAfter() throws Exception { + CelAbstractSyntaxTree ast = + CEL.compile("size([1+1+1]) + size([1+1+1]) + size([1,1+1+1]) + size([1,1+1+1]) + x") + .getAst(); + CelOptimizer optimizer = + CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()), + ConstantFoldingOptimizer.INSTANCE) + .build(); + + CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast); + + assertThat(optimizedAst.getExpr()) + .isEqualTo( + CelExpr.ofCallExpr( + 1L, + Optional.empty(), + Operator.ADD.getFunction(), + ImmutableList.of( + CelExpr.ofConstantExpr(2L, CelConstant.ofValue(6L)), + CelExpr.ofIdentExpr(3L, "x")))); + assertThat(CEL_UNPARSER.unparse(optimizedAst)).isEqualTo("6 + x"); + } + + @Test + public void maxIterationLimitReached_throws() throws Exception { + StringBuilder largeExprBuilder = new StringBuilder(); + int maxIterationLimit = 100; + for (int i = 0; i < maxIterationLimit; i++) { + largeExprBuilder.append("[1,2]"); + if (i < maxIterationLimit - 1) { + largeExprBuilder.append("+"); + } + } + CelAbstractSyntaxTree ast = CEL.compile(largeExprBuilder.toString()).getAst(); + + CelOptimizationException e = + assertThrows( + CelOptimizationException.class, + () -> + 
CelOptimizerFactory.standardCelOptimizerBuilder(CEL) + .addAstOptimizers( + SubexpressionOptimizer.newInstance( + SubexpressionOptimizerOptions.newBuilder() + .maxIterationLimit(maxIterationLimit) + .build())) + .build() + .optimize(ast)); + assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached."); + } +}