diff --git a/optimizer/BUILD.bazel b/optimizer/BUILD.bazel index f0bc0bae..1dd584b4 100644 --- a/optimizer/BUILD.bazel +++ b/optimizer/BUILD.bazel @@ -25,8 +25,6 @@ java_library( java_library( name = "mutable_ast", - testonly = 1, - visibility = ["//optimizer/src/test/java/dev/cel/optimizer:__pkg__"], exports = ["//optimizer/src/main/java/dev/cel/optimizer:mutable_ast"], ) diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel index 454ab828..954deb74 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -66,18 +66,19 @@ java_library( tags = [ ], deps = [ - ":mutable_ast", ":optimization_exception", "//bundle:cel", "//common", - "//common/ast", "//common/navigation", ], ) java_library( name = "mutable_ast", - srcs = ["MutableAst.java"], + srcs = [ + "MutableAst.java", + "MutableExprVisitor.java", + ], tags = [ ], deps = [ @@ -87,6 +88,7 @@ java_library( "//common/ast", "//common/ast:expr_factory", "//common/navigation", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], ) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java index feac4ce9..b9f35dcf 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -16,7 +16,6 @@ import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.ast.CelExpr; import dev.cel.common.navigation.CelNavigableAst; /** Public interface for performing a single, custom optimization on an AST. */ @@ -25,106 +24,4 @@ public interface CelAstOptimizer { /** Optimizes a single AST. */ CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) throws CelOptimizationException; - - /** - * Replaces a subtree in the given expression node. 
This operation is intended for AST - * optimization purposes. - * - *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and - * additionally verify that the resulting AST is semantically valid. - * - *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision - * between the nodes. The renumbering occurs even if the subtree was not replaced. - * - * @param celExpr Original expression node to rewrite. - * @param newExpr New CelExpr to replace the subtree with. - * @param exprIdToReplace Expression id of the subtree that is getting replaced. - */ - default CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) { - return MutableAst.replaceSubtree(celExpr, newExpr, exprIdToReplace); - } - - /** - * Replaces a subtree in the given AST. This operation is intended for AST optimization purposes. - * - *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and - * additionally verify that the resulting AST is semantically valid. - * - *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision - * between the nodes. The renumbering occurs even if the subtree was not replaced. - * - *

This will scrub out the description, positions and line offsets from {@code CelSource}. If - * the source contains macro calls, its call IDs will be to be consistent with the renumbered IDs - * in the AST. - * - * @param ast Original ast to mutate. - * @param newExpr New CelExpr to replace the subtree with. - * @param exprIdToReplace Expression id of the subtree that is getting replaced. - */ - default CelAbstractSyntaxTree replaceSubtree( - CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { - return MutableAst.replaceSubtree(ast, newExpr, exprIdToReplace); - } - - /** - * Generates a new bind macro using the provided initialization and result expression, then - * replaces the subtree using the new bind expr at the designated expr ID. - * - *

The bind call takes the format of: {@code cel.bind(varInit, varName, resultExpr)} - * - * @param ast Original ast to mutate. - * @param varName New variable name for the bind macro call. - * @param varInit Initialization expression to bind to the local variable. - * @param resultExpr Result expression - * @param exprIdToReplace Expression ID of the subtree that is getting replaced. - */ - default CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( - CelAbstractSyntaxTree ast, - String varName, - CelExpr varInit, - CelExpr resultExpr, - long exprIdToReplace) { - return MutableAst.replaceSubtreeWithNewBindMacro( - ast, varName, varInit, resultExpr, exprIdToReplace); - } - - /** - * Replaces all comprehension identifier names with a unique name based on the given prefix. - * - *

The purpose of this is to avoid errors that can be caused by shadowed variables while - * augmenting an AST. As an example: {@code [2, 3].exists(x, x - 1 > 3) || x - 1 > 3}. Note that - * the scoping of `x - 1` is different between th two LOGICAL_OR branches. Iteration variable `x` - * in `exists` will be mangled to {@code [2, 3].exists(@c0, @c0 - 1 > 3) || x - 1 > 3} to avoid - * erroneously extracting x - 1 as common subexpression. - * - *

The expression IDs are not modified when the identifier names are changed. - * - *

Iteration variables in comprehensions are numbered based on their comprehension nesting - * levels. Examples: - * - *

- * - * @param ast AST to mutate - * @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will - * produce @c0, @c1, @c2... as new names. - */ - default CelAbstractSyntaxTree mangleComprehensionIdentifierNames( - CelAbstractSyntaxTree ast, String newIdentPrefix) { - return MutableAst.mangleComprehensionIdentifierNames(ast, newIdentPrefix); - } - - /** Sets all expr IDs in the expression tree to 0. */ - default CelExpr clearExprIds(CelExpr celExpr) { - return MutableAst.clearExprIds(celExpr); - } - - /** Renumbers all the expr IDs in the given AST in a consecutive manner starting from 1. */ - default CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) { - return MutableAst.renumberIdsConsecutively(ast); - } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 1c342716..5562428c 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -19,23 +19,20 @@ import static java.lang.Math.max; import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelSource; -import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelExpr; -import dev.cel.common.ast.CelExpr.CelCall; -import dev.cel.common.ast.CelExpr.CelComprehension; -import dev.cel.common.ast.CelExpr.CelCreateList; -import dev.cel.common.ast.CelExpr.CelCreateMap; -import dev.cel.common.ast.CelExpr.CelCreateStruct; import dev.cel.common.ast.CelExpr.CelIdent; -import dev.cel.common.ast.CelExpr.CelSelect; import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.ast.CelExprFactory; 
import dev.cel.common.ast.CelExprIdGeneratorFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator; +import dev.cel.common.ast.CelExprIdGeneratorFactory.MonotonicIdGenerator; import dev.cel.common.ast.CelExprIdGeneratorFactory.StableIdGenerator; import dev.cel.common.navigation.CelNavigableAst; import dev.cel.common.navigation.CelNavigableExpr; @@ -44,100 +41,102 @@ import java.util.NoSuchElementException; import java.util.Optional; -/** MutableAst contains logic for mutating a {@link CelExpr}. */ -@Internal -final class MutableAst { - private static final int MAX_ITERATION_COUNT = 1000; +/** MutableAst contains logic for mutating a {@link CelAbstractSyntaxTree}. */ +@Immutable +public final class MutableAst { private static final ExprIdGenerator NO_OP_ID_GENERATOR = id -> id; + private final long iterationLimit; - private final CelExpr.Builder newExpr; - private final ExprIdGenerator celExprIdGenerator; - private int iterationCount; - private long exprIdToReplace; + /** + * Returns a new instance of a Mutable AST with the iteration limit set. + * + *

Mutation is performed by walking the existing AST until the expression node to replace is + * found, then the new subtree is walked to complete the mutation. Visiting of each node + * increments the iteration counter. Replace subtree operations will throw an exception if this + * counter reaches the limit. + * + * @param iterationLimit Must be greater than 0. + */ + public static MutableAst newInstance(long iterationLimit) { + return new MutableAst(iterationLimit); + } - private MutableAst(ExprIdGenerator celExprIdGenerator, CelExpr.Builder newExpr, long exprId) { - this.celExprIdGenerator = celExprIdGenerator; - this.newExpr = newExpr; - this.exprIdToReplace = exprId; + private MutableAst(long iterationLimit) { + Preconditions.checkState(iterationLimit > 0L); + this.iterationLimit = iterationLimit; } /** Replaces all the expression IDs in the expression tree with 0. */ - static CelExpr clearExprIds(CelExpr celExpr) { + public CelExpr clearExprIds(CelExpr celExpr) { return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build(); } - /** Mutates the given {@link CelExpr} by replacing a subtree at the given index. */ - static CelExpr replaceSubtree(CelExpr expr, CelExpr newExpr, long exprIdToReplace) { - return replaceSubtree( - CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()), - CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()), + /** + * Replaces a subtree in the given expression node. This operation is intended for AST + * optimization purposes. + * + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + *

If the ability to unparse an expression containing a macro call must be retained, use {@link + * #replaceSubtree(CelAbstractSyntaxTree, CelExpr, long)} instead. + * + * @param celExpr Original expression node to rewrite. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. + */ + public CelExpr replaceSubtree(CelExpr celExpr, CelExpr newExpr, long exprIdToReplace) { + MonotonicIdGenerator monotonicIdGenerator = + CelExprIdGeneratorFactory.newMonotonicIdGenerator(0); + return mutateExpr( + unused -> monotonicIdGenerator.nextExprId(), + celExpr.toBuilder(), + newExpr.toBuilder(), exprIdToReplace) - .getExpr(); + .build(); } /** - * Mutates the given AST by replacing a subtree at a given index. + * Replaces a subtree in the given AST. This operation is intended for AST optimization purposes. * - * @param ast Existing AST being mutated - * @param newExpr New subtree to perform the replacement with. - * @param exprIdToReplace The expr ID in the existing AST to replace the subtree at. + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + *

This will scrub out the description, positions and line offsets from {@code CelSource}. If + * the source contains macro calls, its call IDs will be normalized to be consistent with the renumbered IDs + * in the AST. + * + * @param ast Original ast to mutate. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. */ - static CelAbstractSyntaxTree replaceSubtree( + public CelAbstractSyntaxTree replaceSubtree( CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { - return replaceSubtree( + return replaceSubtreeWithNewAst( ast, CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()), exprIdToReplace); } /** - * Mutates the given AST by replacing a subtree at a given index. + * Generates a new bind macro using the provided initialization and result expression, then + * replaces the subtree using the new bind expr at the designated expr ID. * - * @param ast Existing AST being mutated - * @param newAst New subtree to perform the replacement with. If the subtree has a macro map - * populated, its macro source is merged with the existing AST's after normalization. - * @param exprIdToReplace The expr ID in the existing AST to replace the subtree at. + *

The bind call takes the format of: {@code cel.bind(varInit, varName, resultExpr)} + * + * @param ast Original ast to mutate. + * @param varName New variable name for the bind macro call. + * @param varInit Initialization expression to bind to the local variable. + * @param resultExpr Result expression + * @param exprIdToReplace Expression ID of the subtree that is getting replaced. */ - static CelAbstractSyntaxTree replaceSubtree( - CelAbstractSyntaxTree ast, CelAbstractSyntaxTree newAst, long exprIdToReplace) { - // Stabilize the incoming AST by renumbering all of its expression IDs. - long maxId = max(getMaxId(ast), getMaxId(newAst)); - newAst = stabilizeAst(newAst, maxId); - - // Mutate the AST root with the new subtree. All the existing expr IDs are renumbered in the - // process, but its original IDs are memoized so that we can normalize the expr IDs - // in the macro source map. - StableIdGenerator stableIdGenerator = - CelExprIdGeneratorFactory.newStableIdGenerator(getMaxId(newAst)); - CelExpr.Builder mutatedRoot = - replaceSubtreeImpl( - stableIdGenerator::renumberId, - ast.getExpr().toBuilder(), - newAst.getExpr().toBuilder(), - exprIdToReplace); - - CelSource newAstSource = ast.getSource(); - if (!newAst.getSource().getMacroCalls().isEmpty()) { - // The root is mutated, but the expr IDs in the macro map needs to be normalized. - // In situations where an AST with a new macro map is being inserted (ex: new bind call), - // the new subtree's expr ID is not memoized in the stable ID generator because the ID never - // existed in the main AST. - // In this case, we forcibly memoize the new subtree ID with a newly generated ID so - // that the macro map IDs can be normalized properly. 
- stableIdGenerator.memoize( - newAst.getExpr().id(), stableIdGenerator.renumberId(exprIdToReplace)); - newAstSource = combine(newAstSource, newAst.getSource()); - } - - newAstSource = - normalizeMacroSource( - newAstSource, exprIdToReplace, mutatedRoot, stableIdGenerator::renumberId); - - return CelAbstractSyntaxTree.newParsedAst(mutatedRoot.build(), newAstSource); - } - - /** Replaces the subtree at the given ID with a newly created bind macro. */ - static CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( + public CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( CelAbstractSyntaxTree ast, String varName, CelExpr varInit, @@ -160,11 +159,12 @@ static CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( .addMacroCalls(bindMacro.bindExpr().id(), bindMacro.bindMacro()) .build(); - return replaceSubtree( + return replaceSubtreeWithNewAst( ast, CelAbstractSyntaxTree.newParsedAst(bindMacro.bindExpr(), celSource), exprIdToReplace); } - static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) { + /** Renumbers all the expr IDs in the given AST in a consecutive manner starting from 1. */ + public CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) { StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); CelExpr.Builder root = renumberExprIds(stableIdGenerator::renumberId, ast.getExpr().toBuilder()); @@ -175,12 +175,37 @@ static CelAbstractSyntaxTree renumberIdsConsecutively(CelAbstractSyntaxTree ast) return CelAbstractSyntaxTree.newParsedAst(root.build(), newSource); } - static CelAbstractSyntaxTree mangleComprehensionIdentifierNames( + /** + * Replaces all comprehension identifier names with a unique name based on the given prefix. + * + *

The purpose of this is to avoid errors that can be caused by shadowed variables while + * augmenting an AST. As an example: {@code [2, 3].exists(x, x - 1 > 3) || x - 1 > 3}. Note that + * the scoping of `x - 1` is different between the two LOGICAL_OR branches. Iteration variable `x` + * in `exists` will be mangled to {@code [2, 3].exists(@c0, @c0 - 1 > 3) || x - 1 > 3} to avoid + * erroneously extracting x - 1 as a common subexpression. + * + *

The expression IDs are not modified when the identifier names are changed. + * + *

Iteration variables in comprehensions are numbered based on their comprehension nesting + * levels. Examples: + * + *

+ * + * @param ast AST to mutate + * @param newIdentPrefix Prefix to use for new identifier names. For example, providing @c will + * produce @c0, @c1, @c2... as new names. + */ + public CelAbstractSyntaxTree mangleComprehensionIdentifierNames( CelAbstractSyntaxTree ast, String newIdentPrefix) { int iterCount; CelNavigableAst newNavigableAst = CelNavigableAst.fromAst(ast); - for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) { - Optional maybeComprehensionExpr = + for (iterCount = 0; iterCount < iterationLimit; iterCount++) { + CelNavigableExpr comprehensionNode = newNavigableAst .getRoot() // This is important - mangling needs to happen bottom-up to avoid stepping over @@ -188,14 +213,15 @@ static CelAbstractSyntaxTree mangleComprehensionIdentifierNames( .allNodes(TraversalOrder.POST_ORDER) .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) .filter(node -> !node.expr().comprehension().iterVar().startsWith(newIdentPrefix)) - .findAny(); - if (!maybeComprehensionExpr.isPresent()) { + .findAny() + .orElse(null); + if (comprehensionNode == null) { break; } - CelExpr.Builder comprehensionExpr = maybeComprehensionExpr.get().expr().toBuilder(); + CelExpr.Builder comprehensionExpr = comprehensionNode.expr().toBuilder(); String iterVar = comprehensionExpr.comprehension().iterVar(); - int comprehensionNestingLevel = countComprehensionNestingLevel(maybeComprehensionExpr.get()); + int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode); String mangledVarName = newIdentPrefix + comprehensionNestingLevel; CelExpr.Builder mutatedComprehensionExpr = @@ -218,20 +244,70 @@ static CelAbstractSyntaxTree mangleComprehensionIdentifierNames( CelAbstractSyntaxTree.newParsedAst(mutatedComprehensionExpr.build(), newSource)); } - if (iterCount >= MAX_ITERATION_COUNT) { + if (iterCount >= iterationLimit) { + // Note that it's generally impossible to reach this for a well-formed AST. 
The nesting level + // of AST being mutated is always deeper than the number of identifiers being mangled, thus + // the mutation operation should throw before we ever reach here. throw new IllegalStateException("Max iteration count reached."); } return newNavigableAst.getAst(); } - private static CelExpr.Builder mangleIdentsInComprehensionExpr( + /** + * Mutates the given AST by replacing a subtree at a given index. + * + * @param ast Existing AST being mutated + * @param newAst New subtree to perform the replacement with. If the subtree has a macro map + * populated, its macro source is merged with the existing AST's after normalization. + * @param exprIdToReplace The expr ID in the existing AST to replace the subtree at. + */ + @VisibleForTesting + CelAbstractSyntaxTree replaceSubtreeWithNewAst( + CelAbstractSyntaxTree ast, CelAbstractSyntaxTree newAst, long exprIdToReplace) { + // Stabilize the incoming AST by renumbering all of its expression IDs. + long maxId = max(getMaxId(ast), getMaxId(newAst)); + newAst = stabilizeAst(newAst, maxId); + + // Mutate the AST root with the new subtree. All the existing expr IDs are renumbered in the + // process, but its original IDs are memoized so that we can normalize the expr IDs + // in the macro source map. + StableIdGenerator stableIdGenerator = + CelExprIdGeneratorFactory.newStableIdGenerator(getMaxId(newAst)); + CelExpr.Builder mutatedRoot = + mutateExpr( + stableIdGenerator::renumberId, + ast.getExpr().toBuilder(), + newAst.getExpr().toBuilder(), + exprIdToReplace); + + CelSource newAstSource = ast.getSource(); + if (!newAst.getSource().getMacroCalls().isEmpty()) { + // The root is mutated, but the expr IDs in the macro map needs to be normalized. + // In situations where an AST with a new macro map is being inserted (ex: new bind call), + // the new subtree's expr ID is not memoized in the stable ID generator because the ID never + // existed in the main AST. 
+ // In this case, we forcibly memoize the new subtree ID with a newly generated ID so + // that the macro map IDs can be normalized properly. + stableIdGenerator.memoize( + newAst.getExpr().id(), stableIdGenerator.renumberId(exprIdToReplace)); + newAstSource = combine(newAstSource, newAst.getSource()); + } + + newAstSource = + normalizeMacroSource( + newAstSource, exprIdToReplace, mutatedRoot, stableIdGenerator::renumberId); + + return CelAbstractSyntaxTree.newParsedAst(mutatedRoot.build(), newAstSource); + } + + private CelExpr.Builder mangleIdentsInComprehensionExpr( CelExpr.Builder root, CelExpr.Builder comprehensionExpr, String originalIterVar, String mangledVarName) { int iterCount; - for (iterCount = 0; iterCount < MAX_ITERATION_COUNT; iterCount++) { + for (iterCount = 0; iterCount < iterationLimit; iterCount++) { Optional identToMangle = CelNavigableExpr.fromExpr(comprehensionExpr.comprehension().loopStep()) .descendants() @@ -243,18 +319,18 @@ private static CelExpr.Builder mangleIdentsInComprehensionExpr( } comprehensionExpr = - replaceSubtreeImpl( + mutateExpr( NO_OP_ID_GENERATOR, comprehensionExpr, CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), identToMangle.get().id()); } - if (iterCount >= MAX_ITERATION_COUNT) { + if (iterCount >= iterationLimit) { throw new IllegalStateException("Max iteration count reached."); } - return replaceSubtreeImpl( + return mutateExpr( NO_OP_ID_GENERATOR, root, comprehensionExpr.setComprehension( @@ -262,7 +338,7 @@ private static CelExpr.Builder mangleIdentsInComprehensionExpr( comprehensionExpr.id()); } - private static CelSource mangleIdentsInMacroSource( + private CelSource mangleIdentsInMacroSource( CelAbstractSyntaxTree ast, CelExpr.Builder mutatedComprehensionExpr, String originalIterVar, @@ -291,7 +367,7 @@ private static CelSource mangleIdentsInMacroSource( identToMangle.identOrDefault().name(), originalIterVar)); } macroExpr = - replaceSubtreeImpl( + mutateExpr( 
NO_OP_ID_GENERATOR, macroExpr, CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName(mangledVarName).build()), @@ -301,7 +377,7 @@ private static CelSource mangleIdentsInMacroSource( return newSource.build(); } - private static BindMacro newBindMacro( + private BindMacro newBindMacro( String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) { // Renumber incoming expression IDs in the init and result expression to avoid collision with // the main AST. Existing IDs are memoized for a macro source sanitization pass at the end @@ -348,7 +424,7 @@ private static CelSource combine(CelSource celSource1, CelSource celSource2) { * (monotonically increased) from the starting seed ID. If the AST contains any macro calls, its * IDs are also normalized. */ - private static CelAbstractSyntaxTree stabilizeAst(CelAbstractSyntaxTree ast, long seedExprId) { + private CelAbstractSyntaxTree stabilizeAst(CelAbstractSyntaxTree ast, long seedExprId) { StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(seedExprId); CelExpr.Builder newExprBuilder = @@ -373,7 +449,7 @@ private static CelAbstractSyntaxTree stabilizeAst(CelAbstractSyntaxTree ast, lon return CelAbstractSyntaxTree.newParsedAst(newExprBuilder.build(), sourceBuilder.build()); } - private static CelSource normalizeMacroSource( + private CelSource normalizeMacroSource( CelSource celSource, long exprIdToReplace, CelExpr.Builder mutatedRoot, @@ -424,8 +500,7 @@ private static CelSource normalizeMacroSource( CelExpr mutatedExpr = allExprs.get(callChild.id()); if (!callChild.equals(mutatedExpr)) { - newCall = - replaceSubtreeImpl((arg) -> arg, newCall, mutatedExpr.toBuilder(), callChild.id()); + newCall = mutateExpr((arg) -> arg, newCall, mutatedExpr.toBuilder(), callChild.id()); } } sourceBuilder.addMacroCalls(callId, newCall.build()); @@ -441,7 +516,7 @@ private static CelSource normalizeMacroSource( .forEach( node -> { CelExpr.Builder mutatedNode = - 
replaceSubtreeImpl( + mutateExpr( (id) -> id, macroCallExpr.toBuilder(), CelExpr.ofNotSet(node.id()).toBuilder(), @@ -453,18 +528,19 @@ private static CelSource normalizeMacroSource( return sourceBuilder.build(); } - private static CelExpr.Builder replaceSubtreeImpl( + private CelExpr.Builder mutateExpr( ExprIdGenerator idGenerator, CelExpr.Builder root, CelExpr.Builder newExpr, long exprIdToReplace) { - MutableAst mutableAst = new MutableAst(idGenerator, newExpr, exprIdToReplace); + MutableExprVisitor mutableAst = + MutableExprVisitor.newInstance(idGenerator, newExpr, exprIdToReplace, iterationLimit); return mutableAst.visit(root); } - private static CelExpr.Builder renumberExprIds( - ExprIdGenerator idGenerator, CelExpr.Builder root) { - MutableAst mutableAst = new MutableAst(idGenerator, root, Integer.MIN_VALUE); + private CelExpr.Builder renumberExprIds(ExprIdGenerator idGenerator, CelExpr.Builder root) { + MutableExprVisitor mutableAst = + MutableExprVisitor.newInstance(idGenerator, root, Integer.MIN_VALUE, iterationLimit); return mutableAst.visit(root); } @@ -498,105 +574,6 @@ private static int countComprehensionNestingLevel(CelNavigableExpr comprehension return nestedLevel; } - private CelExpr.Builder visit(CelExpr.Builder expr) { - if (++iterationCount > MAX_ITERATION_COUNT) { - throw new IllegalStateException("Max iteration count reached."); - } - - if (expr.id() == exprIdToReplace) { - exprIdToReplace = Integer.MIN_VALUE; // Marks that the subtree has been replaced. 
- return visit(newExpr.setId(expr.id())); - } - - expr.setId(celExprIdGenerator.generate(expr.id())); - - switch (expr.exprKind().getKind()) { - case SELECT: - return visit(expr, expr.select().toBuilder()); - case CALL: - return visit(expr, expr.call().toBuilder()); - case CREATE_LIST: - return visit(expr, expr.createList().toBuilder()); - case CREATE_STRUCT: - return visit(expr, expr.createStruct().toBuilder()); - case CREATE_MAP: - return visit(expr, expr.createMap().toBuilder()); - case COMPREHENSION: - return visit(expr, expr.comprehension().toBuilder()); - case CONSTANT: // Fall-through is intended - case IDENT: - case NOT_SET: // Note: comprehension arguments can contain a not set expr. - return expr; - default: - throw new IllegalArgumentException("unexpected expr kind: " + expr.exprKind().getKind()); - } - } - - private CelExpr.Builder visit(CelExpr.Builder expr, CelSelect.Builder select) { - select.setOperand(visit(select.operand().toBuilder()).build()); - return expr.setSelect(select.build()); - } - - private CelExpr.Builder visit(CelExpr.Builder expr, CelCall.Builder call) { - if (call.target().isPresent()) { - call.setTarget(visit(call.target().get().toBuilder()).build()); - } - ImmutableList argsBuilders = call.getArgsBuilders(); - for (int i = 0; i < argsBuilders.size(); i++) { - CelExpr.Builder arg = argsBuilders.get(i); - call.setArg(i, visit(arg).build()); - } - - return expr.setCall(call.build()); - } - - private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateStruct.Builder createStruct) { - ImmutableList entries = createStruct.getEntriesBuilders(); - for (int i = 0; i < entries.size(); i++) { - CelCreateStruct.Entry.Builder entry = entries.get(i); - entry.setId(celExprIdGenerator.generate(entry.id())); - entry.setValue(visit(entry.value().toBuilder()).build()); - - createStruct.setEntry(i, entry.build()); - } - - return expr.setCreateStruct(createStruct.build()); - } - - private CelExpr.Builder visit(CelExpr.Builder expr, 
CelCreateMap.Builder createMap) { - ImmutableList entriesBuilders = createMap.getEntriesBuilders(); - for (int i = 0; i < entriesBuilders.size(); i++) { - CelCreateMap.Entry.Builder entry = entriesBuilders.get(i); - entry.setId(celExprIdGenerator.generate(entry.id())); - entry.setKey(visit(entry.key().toBuilder()).build()); - entry.setValue(visit(entry.value().toBuilder()).build()); - - createMap.setEntry(i, entry.build()); - } - - return expr.setCreateMap(createMap.build()); - } - - private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateList.Builder createList) { - ImmutableList elementsBuilders = createList.getElementsBuilders(); - for (int i = 0; i < elementsBuilders.size(); i++) { - CelExpr.Builder elem = elementsBuilders.get(i); - createList.setElement(i, visit(elem).build()); - } - - return expr.setCreateList(createList.build()); - } - - private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder comprehension) { - comprehension.setIterRange(visit(comprehension.iterRange().toBuilder()).build()); - comprehension.setAccuInit(visit(comprehension.accuInit().toBuilder()).build()); - comprehension.setLoopCondition(visit(comprehension.loopCondition().toBuilder()).build()); - comprehension.setLoopStep(visit(comprehension.loopStep().toBuilder()).build()); - comprehension.setResult(visit(comprehension.result().toBuilder()).build()); - - return expr.setComprehension(comprehension.build()); - } - /** * Intermediate value class to store the generated CelExpr for the bind macro and the macro call * information. 
diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableExprVisitor.java b/optimizer/src/main/java/dev/cel/optimizer/MutableExprVisitor.java
new file mode 100644
index 00000000..36cb8606
--- /dev/null
+++ b/optimizer/src/main/java/dev/cel/optimizer/MutableExprVisitor.java
@@ -0,0 +1,197 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev.cel.optimizer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import dev.cel.common.annotations.Internal;
+import dev.cel.common.ast.CelExpr;
+import dev.cel.common.ast.CelExpr.CelCall;
+import dev.cel.common.ast.CelExpr.CelComprehension;
+import dev.cel.common.ast.CelExpr.CelCreateList;
+import dev.cel.common.ast.CelExpr.CelCreateMap;
+import dev.cel.common.ast.CelExpr.CelCreateStruct;
+import dev.cel.common.ast.CelExpr.CelSelect;
+import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator;
+
+/**
+ * MutableExprVisitor performs mutation of {@link CelExpr} based on its configured parameters.
+ *
+ * <p>This class is NOT thread-safe. Callers should spawn a new instance of this class each time the
+ * expression is being mutated.
+ *
+ * <p>Note that CelExpr is immutable by design. Therefore, the logic here doesn't actually mutate
+ * the existing expression tree. Instead, a brand new CelExpr is produced with the subtree swapped
+ * at the desired expression ID to replace.
+ */
+@Internal
+final class MutableExprVisitor {
+  /** Value stored in {@code exprIdToReplace} once the replacement has been performed. */
+  private static final long REPLACED_MARKER = Integer.MIN_VALUE;
+
+  private final CelExpr.Builder newExpr;
+  private final ExprIdGenerator celExprIdGenerator;
+  private final long iterationLimit;
+  // Kept as a long, like iterationLimit, so the counter cannot wrap before the limit is reached.
+  private long iterationCount;
+  private long exprIdToReplace;
+
+  /**
+   * Creates a visitor that swaps the node whose ID is {@code exprIdToReplace} with {@code newExpr}.
+   *
+   * @param idGenerator maps the current ID of each visited node or entry to its new ID
+   * @param newExpr expression inserted in place of the node being replaced
+   * @param exprIdToReplace ID of the node to replace
+   * @param iterationLimit positive limit; {@code visit} may run up to twice this many times
+   * @throws IllegalStateException if {@code iterationLimit} is not positive
+   */
+  static MutableExprVisitor newInstance(
+      ExprIdGenerator idGenerator,
+      CelExpr.Builder newExpr,
+      long exprIdToReplace,
+      long iterationLimit) {
+    // Validate before doubling: a large negative value could otherwise wrap to a positive one.
+    Preconditions.checkState(
+        iterationLimit > 0L, "iterationLimit must be positive: %s", iterationLimit);
+    // iterationLimit * 2, because the expr can be walked twice due to the immutable nature of
+    // CelExpr. The product saturates at Long.MAX_VALUE instead of overflowing.
+    long doubledLimit = iterationLimit > Long.MAX_VALUE / 2 ? Long.MAX_VALUE : iterationLimit * 2;
+    return new MutableExprVisitor(idGenerator, newExpr, exprIdToReplace, doubledLimit);
+  }
+
+  /**
+   * Walks the tree rooted at {@code root}, assigning generated IDs and swapping in the replacement.
+   *
+   * @param root builder of the (sub)tree to walk
+   * @return the walked builder, or the walked replacement if {@code root} had the ID to replace
+   * @throws IllegalStateException if this method is invoked more times than the iteration budget
+   * @throws IllegalArgumentException if the expression kind is not handled
+   */
+  CelExpr.Builder visit(CelExpr.Builder root) {
+    if (++iterationCount > iterationLimit) {
+      throw new IllegalStateException("Max iteration count reached.");
+    }
+
+    if (root.id() == exprIdToReplace) {
+      exprIdToReplace = REPLACED_MARKER; // Marks that the subtree has been replaced.
+      return visit(this.newExpr.setId(root.id()));
+    }
+
+    root.setId(celExprIdGenerator.generate(root.id()));
+
+    switch (root.exprKind().getKind()) {
+      case SELECT:
+        return visit(root, root.select().toBuilder());
+      case CALL:
+        return visit(root, root.call().toBuilder());
+      case CREATE_LIST:
+        return visit(root, root.createList().toBuilder());
+      case CREATE_STRUCT:
+        return visit(root, root.createStruct().toBuilder());
+      case CREATE_MAP:
+        return visit(root, root.createMap().toBuilder());
+      case COMPREHENSION:
+        return visit(root, root.comprehension().toBuilder());
+      case CONSTANT: // Fall-through is intended
+      case IDENT:
+      case NOT_SET: // Note: comprehension arguments can contain a not set root.
+        return root;
+    }
+    throw new IllegalArgumentException("unexpected root kind: " + root.exprKind().getKind());
+  }
+
+  /** Visits the operand of a select expression. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelSelect.Builder select) {
+    select.setOperand(visit(select.operand().toBuilder()).build());
+    return expr.setSelect(select.build());
+  }
+
+  /** Visits the target (if present) and every argument of a call expression. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelCall.Builder call) {
+    if (call.target().isPresent()) {
+      call.setTarget(visit(call.target().get().toBuilder()).build());
+    }
+    ImmutableList<CelExpr.Builder> argsBuilders = call.getArgsBuilders();
+    for (int i = 0; i < argsBuilders.size(); i++) {
+      CelExpr.Builder arg = argsBuilders.get(i);
+      call.setArg(i, visit(arg).build());
+    }
+
+    return expr.setCall(call.build());
+  }
+
+  /** Regenerates each struct entry ID and visits each entry value. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateStruct.Builder createStruct) {
+    ImmutableList<CelCreateStruct.Entry.Builder> entries = createStruct.getEntriesBuilders();
+    for (int i = 0; i < entries.size(); i++) {
+      CelCreateStruct.Entry.Builder entry = entries.get(i);
+      entry.setId(celExprIdGenerator.generate(entry.id()));
+      entry.setValue(visit(entry.value().toBuilder()).build());
+
+      createStruct.setEntry(i, entry.build());
+    }
+
+    return expr.setCreateStruct(createStruct.build());
+  }
+
+  /** Regenerates each map entry ID and visits each entry key and value. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateMap.Builder createMap) {
+    ImmutableList<CelCreateMap.Entry.Builder> entriesBuilders = createMap.getEntriesBuilders();
+    for (int i = 0; i < entriesBuilders.size(); i++) {
+      CelCreateMap.Entry.Builder entry = entriesBuilders.get(i);
+      entry.setId(celExprIdGenerator.generate(entry.id()));
+      entry.setKey(visit(entry.key().toBuilder()).build());
+      entry.setValue(visit(entry.value().toBuilder()).build());
+
+      createMap.setEntry(i, entry.build());
+    }
+
+    return expr.setCreateMap(createMap.build());
+  }
+
+  /** Visits every element of a list expression. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateList.Builder createList) {
+    ImmutableList<CelExpr.Builder> elementsBuilders = createList.getElementsBuilders();
+    for (int i = 0; i < elementsBuilders.size(); i++) {
+      CelExpr.Builder elem = elementsBuilders.get(i);
+      createList.setElement(i, visit(elem).build());
+    }
+
+    return expr.setCreateList(createList.build());
+  }
+
+  /** Visits the five sub-expressions of a comprehension. */
+  private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder comprehension) {
+    comprehension.setIterRange(visit(comprehension.iterRange().toBuilder()).build());
+    comprehension.setAccuInit(visit(comprehension.accuInit().toBuilder()).build());
+    comprehension.setLoopCondition(visit(comprehension.loopCondition().toBuilder()).build());
+    comprehension.setLoopStep(visit(comprehension.loopStep().toBuilder()).build());
+    comprehension.setResult(visit(comprehension.result().toBuilder()).build());
+
+    return expr.setComprehension(comprehension.build());
+  }
+
+  private MutableExprVisitor(
+      ExprIdGenerator celExprIdGenerator,
+      CelExpr.Builder newExpr,
+      long exprId,
+      long iterationLimit) {
+    Preconditions.checkState(iterationLimit > 0L);
+    this.iterationLimit = iterationLimit;
+    this.celExprIdGenerator = celExprIdGenerator;
+    this.newExpr = newExpr;
+    this.exprIdToReplace = exprId;
+  }
+}
diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index f2519c83..2dcd3b2d 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -21,6 +21,7 @@ java_library( "//common/ast:expr_util", "//common/navigation", "//optimizer:ast_optimizer", + "//optimizer:mutable_ast", "//optimizer:optimization_exception", "//parser:operator", "//runtime", @@ -43,6 +44,7 @@ java_library( "//common/ast", "//common/navigation", "//optimizer:ast_optimizer", + "//optimizer:mutable_ast", "//parser:operator", "@maven//:com_google_guava_guava", "@maven//:org_jspecify_jspecify", diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index
272cd54f..ed3eb909 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -32,6 +32,7 @@ import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.optimizer.CelAstOptimizer; import dev.cel.optimizer.CelOptimizationException; +import dev.cel.optimizer.MutableAst; import dev.cel.parser.Operator; import dev.cel.runtime.CelEvaluationException; import java.util.Collection; @@ -46,9 +47,10 @@ * calls and select statements with their evaluated result. */ public final class ConstantFoldingOptimizer implements CelAstOptimizer { - public static final ConstantFoldingOptimizer INSTANCE = new ConstantFoldingOptimizer(); private static final int MAX_ITERATION_COUNT = 400; + public static final ConstantFoldingOptimizer INSTANCE = new ConstantFoldingOptimizer(); + // Use optional.of and optional.none as sentinel function names for folding optional calls. // TODO: Leverage CelValue representation of Optionals instead when available. private static final String OPTIONAL_OF_FUNCTION = "optional.of"; @@ -56,6 +58,8 @@ public final class ConstantFoldingOptimizer implements CelAstOptimizer { private static final CelExpr OPTIONAL_NONE_EXPR = CelExpr.ofCallExpr(0, Optional.empty(), OPTIONAL_NONE_FUNCTION, ImmutableList.of()); + private final MutableAst mutableAst; + @Override public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) throws CelOptimizationException { @@ -99,7 +103,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) // If the output is a list, map, or struct which contains optional entries, then prune it // to make sure that the optionals, if resolved, do not surface in the output literal. 
CelAbstractSyntaxTree newAst = pruneOptionalElements(navigableAst); - return renumberIdsConsecutively(newAst); + return mutableAst.renumberIdsConsecutively(newAst); } private static boolean canFold(CelNavigableExpr navigableExpr) { @@ -194,7 +198,7 @@ private Optional maybeFold( } return maybeAdaptEvaluatedResult(result) - .map(celExpr -> replaceSubtree(ast, celExpr, expr.id())); + .map(celExpr -> mutableAst.replaceSubtree(ast, celExpr, expr.id())); } private Optional maybeAdaptEvaluatedResult(Object result) { @@ -247,7 +251,7 @@ private Optional maybeRewriteOptional( // An empty optional value was encountered. Rewrite the tree with optional.none call. // This is to account for other optional functions returning an empty optional value // e.g: optional.ofNonZeroValue(0) - return Optional.of(replaceSubtree(ast, OPTIONAL_NONE_EXPR, expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, OPTIONAL_NONE_EXPR, expr.id())); } } else if (!expr.callOrDefault().function().equals(OPTIONAL_OF_FUNCTION)) { Object unwrappedResult = optResult.get(); @@ -267,7 +271,7 @@ private Optional maybeRewriteOptional( .build()) .build()) .build(); - return Optional.of(replaceSubtree(ast, newOptionalOfCall, expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, newOptionalOfCall, expr.id())); } return Optional.empty(); @@ -296,7 +300,7 @@ private Optional maybePruneBranches( } CelExpr result = cond.constant().booleanValue() ? 
truthy : falsy; - return Optional.of(replaceSubtree(ast, result, expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, result, expr.id())); } else if (function.equals(Operator.IN.getFunction())) { CelExpr callArg = call.args().get(1); if (!callArg.exprKind().getKind().equals(Kind.CREATE_LIST)) { @@ -306,7 +310,7 @@ private Optional maybePruneBranches( CelCreateList haystack = callArg.createList(); if (haystack.elements().isEmpty()) { return Optional.of( - replaceSubtree( + mutableAst.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), expr.id())); @@ -321,7 +325,7 @@ private Optional maybePruneBranches( if (elem.constantOrDefault().equals(needleValue) || elem.identOrDefault().equals(needleValue)) { return Optional.of( - replaceSubtree( + mutableAst.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), expr.id())); @@ -354,16 +358,16 @@ private Optional maybeShortCircuitCall( } if (arg.constant().booleanValue() == shortCircuit) { - return Optional.of(replaceSubtree(ast, arg, expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, arg, expr.id())); } } ImmutableList newArgs = newArgsBuilder.build(); if (newArgs.isEmpty()) { - return Optional.of(replaceSubtree(ast, call.args().get(0), expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, call.args().get(0), expr.id())); } if (newArgs.size() == 1) { - return Optional.of(replaceSubtree(ast, newArgs.get(0), expr.id())); + return Optional.of(mutableAst.replaceSubtree(ast, newArgs.get(0), expr.id())); } // TODO: Support folding variadic AND/ORs. 
@@ -441,7 +445,7 @@ private CelAbstractSyntaxTree pruneOptionalListElements(CelAbstractSyntaxTree as updatedIndicesBuilder.add(newOptIndex); } - return replaceSubtree( + return mutableAst.replaceSubtree( ast, CelExpr.newBuilder() .setCreateList( @@ -488,7 +492,7 @@ private CelAbstractSyntaxTree pruneOptionalMapElements(CelAbstractSyntaxTree ast } if (modified) { - return replaceSubtree( + return mutableAst.replaceSubtree( ast, CelExpr.newBuilder() .setCreateMap( @@ -534,7 +538,7 @@ private CelAbstractSyntaxTree pruneOptionalStructElements( } if (modified) { - return replaceSubtree( + return mutableAst.replaceSubtree( ast, CelExpr.newBuilder() .setCreateStruct( @@ -549,5 +553,7 @@ private CelAbstractSyntaxTree pruneOptionalStructElements( return ast; } - private ConstantFoldingOptimizer() {} + private ConstantFoldingOptimizer() { + this.mutableAst = MutableAst.newInstance(MAX_ITERATION_COUNT); + } } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java index d7eda5b4..8d057c74 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java @@ -33,6 +33,7 @@ import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.navigation.CelNavigableExpr.TraversalOrder; import dev.cel.optimizer.CelAstOptimizer; +import dev.cel.optimizer.MutableAst; import dev.cel.parser.Operator; import java.util.HashMap; import java.util.HashSet; @@ -59,7 +60,6 @@ * */ public class SubexpressionOptimizer implements CelAstOptimizer { - private static final SubexpressionOptimizer INSTANCE = new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build()); private static final String BIND_IDENTIFIER_PREFIX = "@r"; @@ -70,6 +70,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer { 
stream(Standard.Function.values()).map(Standard.Function::getFunction)) .collect(toImmutableSet()); private final SubexpressionOptimizerOptions cseOptions; + private final MutableAst mutableAst; /** * Returns a default instance of common subexpression elimination optimizer with preconfigured @@ -90,13 +91,13 @@ public static SubexpressionOptimizer newInstance(SubexpressionOptimizerOptions c @Override public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { CelAbstractSyntaxTree astToModify = - mangleComprehensionIdentifierNames( + mutableAst.mangleComprehensionIdentifierNames( navigableAst.getAst(), MANGLED_COMPREHENSION_IDENTIFIER_PREFIX); CelSource sourceToModify = astToModify.getSource(); int bindIdentifierIndex = 0; int iterCount; - for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { + for (iterCount = 0; iterCount < cseOptions.iterationLimit(); iterCount++) { CelExpr cseCandidate = findCseCandidate(astToModify).map(CelNavigableExpr::expr).orElse(null); if (cseCandidate == null) { break; @@ -122,7 +123,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { "No value present for expr ID: " + semanticallyEqualNode.id())); astToModify = - replaceSubtree( + mutableAst.replaceSubtree( astToModify, CelExpr.newBuilder() .setIdent(CelIdent.newBuilder().setName(bindIdentifier).build()) @@ -141,7 +142,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { // Insert the new bind call astToModify = - replaceSubtreeWithNewBindMacro( + mutableAst.replaceSubtreeWithNewBindMacro( astToModify, bindIdentifier, cseCandidate, lca.expr(), lca.id()); // Retain the existing macro calls in case if the bind identifiers are replacing a subtree @@ -149,7 +150,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { sourceToModify = astToModify.getSource(); } - if (iterCount >= cseOptions.maxIterationLimit()) { + if (iterCount >= 
cseOptions.iterationLimit()) { throw new IllegalStateException("Max iteration count reached."); } @@ -163,7 +164,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel) { CelAbstractSyntaxTree.newParsedAst(astToModify.getExpr(), CelSource.newBuilder().build()); } - return renumberIdsConsecutively(astToModify); + return mutableAst.renumberIdsConsecutively(astToModify); } private Stream getAllCseCandidatesStream( @@ -297,7 +298,7 @@ private static boolean isAllowedFunction(CelNavigableExpr navigableExpr) { */ private CelExpr normalizeForEquality(CelExpr celExpr) { int iterCount; - for (iterCount = 0; iterCount < cseOptions.maxIterationLimit(); iterCount++) { + for (iterCount = 0; iterCount < cseOptions.iterationLimit(); iterCount++) { CelExpr presenceTestExpr = CelNavigableExpr.fromExpr(celExpr) .allNodes() @@ -314,20 +315,20 @@ private CelExpr normalizeForEquality(CelExpr celExpr) { .setSelect(presenceTestExpr.select().toBuilder().setTestOnly(false).build()) .build(); - celExpr = replaceSubtree(celExpr, newExpr, newExpr.id()); + celExpr = mutableAst.replaceSubtree(celExpr, newExpr, newExpr.id()); } - if (iterCount >= cseOptions.maxIterationLimit()) { + if (iterCount >= cseOptions.iterationLimit()) { throw new IllegalStateException("Max iteration count reached."); } - return clearExprIds(celExpr); + return mutableAst.clearExprIds(celExpr); } /** Options to configure how Common Subexpression Elimination behave. */ @AutoValue public abstract static class SubexpressionOptimizerOptions { - public abstract int maxIterationLimit(); + public abstract int iterationLimit(); public abstract boolean populateMacroCalls(); @@ -339,7 +340,7 @@ public abstract static class Builder { * Limit the number of iteration while performing CSE. An exception is thrown if the iteration * count exceeds the set value. 
*/ - public abstract Builder maxIterationLimit(int value); + public abstract Builder iterationLimit(int value); /** * Populate the macro_calls map in source_info with macro calls on the resulting optimized @@ -355,7 +356,7 @@ public abstract static class Builder { /** Returns a new options builder with recommended defaults pre-configured. */ public static Builder newBuilder() { return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder() - .maxIterationLimit(500) + .iterationLimit(500) .populateMacroCalls(false); } @@ -364,5 +365,6 @@ public static Builder newBuilder() { private SubexpressionOptimizer(SubexpressionOptimizerOptions cseOptions) { this.cseOptions = cseOptions; + this.mutableAst = MutableAst.newInstance(cseOptions.iterationLimit()); } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index 8f79838c..d027c6b8 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -16,6 +16,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameterInjector; @@ -61,13 +62,14 @@ public class MutableAstTest { .build(); private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + private static final MutableAst MUTABLE_AST = MutableAst.newInstance(1000); @Test public void constExpr() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); assertThat(mutatedAst.getExpr()) @@ -79,7 +81,7 @@ public void mutableAst_returnsParsedAst() throws 
Exception { CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); assertThat(ast.isChecked()).isTrue(); @@ -91,7 +93,7 @@ public void mutableAst_nonMacro_sourceCleared() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("10").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); assertThat(mutatedAst.getSource().getDescription()).isEmpty(); @@ -105,7 +107,7 @@ public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); assertThat(mutatedAst.getSource().getDescription()).isEmpty(); @@ -126,7 +128,8 @@ public void replaceSubtree_rootReplacedWithMacro_macroCallPopulated( CelAbstractSyntaxTree ast2 = CEL.compile(source).getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); + MUTABLE_AST.replaceSubtreeWithNewAst( + ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(expectedMacroCallSize); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo(source); @@ -139,9 +142,9 @@ public void replaceSubtree_branchReplacedWithMacro_macroCallPopulated() throws E CelAbstractSyntaxTree ast2 = CEL.compile("[1].exists(x, x > 0)").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, 3); // Replace false with the macro expr + MUTABLE_AST.replaceSubtreeWithNewAst(ast, ast2, 3); // Replace false with the macro expr CelAbstractSyntaxTree 
mutatedAst2 = - MutableAst.replaceSubtree(ast, ast2, 1); // Replace true with the macro expr + MUTABLE_AST.replaceSubtreeWithNewAst(ast, ast2, 1); // Replace true with the macro expr assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(1); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("true && [1].exists(x, x > 0)"); @@ -157,7 +160,7 @@ public void replaceSubtree_macroInsertedIntoExistingMacro_macroCallPopulated() t CelAbstractSyntaxTree ast2 = CEL.compile("[2].exists(y, y > 0)").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, 9); // Replace true with the ast2 maro expr + MUTABLE_AST.replaceSubtreeWithNewAst(ast, ast2, 9); // Replace true with the ast2 maro expr assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(2); assertThat(CEL_UNPARSER.unparse(mutatedAst)) @@ -180,7 +183,7 @@ public void replaceSubtreeWithNewBindMacro_replaceRoot() throws Exception { .build(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtreeWithNewBindMacro( + MUTABLE_AST.replaceSubtreeWithNewBindMacro( ast, variableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), @@ -212,7 +215,7 @@ public void replaceSubtreeWithNewBindMacro_nestedBindMacro_replaceComprehensionR // Act // Perform the initial replacement. 
(1 + 1) -> cel.bind(@r0, 3, @r0 + @r0) CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtreeWithNewBindMacro( + MUTABLE_AST.replaceSubtreeWithNewBindMacro( ast, variableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), @@ -261,7 +264,7 @@ public void replaceSubtreeWithNewBindMacro_nestedBindMacro_replaceComprehensionR .id(); // This should produce cel.bind(@r1, 1, cel.bind(@r0, 3, @r0 + @r0 + @r1 + @r1)) mutatedAst = - MutableAst.replaceSubtreeWithNewBindMacro( + MUTABLE_AST.replaceSubtreeWithNewBindMacro( mutatedAst, nestedVariableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(1L)), @@ -293,7 +296,7 @@ public void replaceSubtreeWithNewBindMacro_replaceRootWithNestedBindMacro() thro // Act // Perform the initial replacement. (1 + 1 + 3 + 3) -> cel.bind(@r0, 1, @r0 + @r0) + 3 + 3 CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtreeWithNewBindMacro( + MUTABLE_AST.replaceSubtreeWithNewBindMacro( ast, variableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(1L)), @@ -328,7 +331,7 @@ public void replaceSubtreeWithNewBindMacro_replaceRootWithNestedBindMacro() thro .build(); // Replace the root with the new result and a bind macro inserted mutatedAst = - MutableAst.replaceSubtreeWithNewBindMacro( + MUTABLE_AST.replaceSubtreeWithNewBindMacro( mutatedAst, nestedVariableName, CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), @@ -349,7 +352,8 @@ public void replaceSubtree_macroReplacedWithConstExpr_macroCallCleared() throws CelAbstractSyntaxTree ast2 = CEL.compile("1").getAst(); CelAbstractSyntaxTree mutatedAst = - MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); + MUTABLE_AST.replaceSubtreeWithNewAst( + ast, ast2, CelNavigableAst.fromAst(ast).getRoot().id()); assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty(); assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("1"); @@ -365,7 +369,7 @@ public void globalCallExpr_replaceRoot() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("1 + 2 
+ x").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); assertThat(replacedAst.getExpr()).isEqualTo(CelExpr.ofConstantExpr(7, CelConstant.ofValue(10))); @@ -380,7 +384,7 @@ public void globalCallExpr_replaceLeaf() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 1); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10 + 2 + x"); @@ -395,7 +399,7 @@ public void globalCallExpr_replaceMiddleBranch() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 2); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10 + x"); @@ -410,7 +414,7 @@ public void globalCallExpr_replaceMiddleBranch_withCallExpr() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("1 + 2 + x").getAst(); CelAbstractSyntaxTree ast2 = CEL.compile("4 + 5 + 6").getAst(); - CelAbstractSyntaxTree replacedAst = MutableAst.replaceSubtree(ast, ast2.getExpr(), 2); + CelAbstractSyntaxTree replacedAst = MUTABLE_AST.replaceSubtree(ast, ast2.getExpr(), 2); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("4 + 5 + 6 + x"); } @@ -432,7 +436,7 @@ public void memberCallExpr_replaceLeafTarget() throws Exception { CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 3); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(20.func(5))"); @@ -455,7 +459,7 @@ public void 
memberCallExpr_replaceLeafArgument() throws Exception { CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 5); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(4.func(20))"); @@ -478,7 +482,7 @@ public void memberCallExpr_replaceMiddleBranchTarget() throws Exception { CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 1); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("20.func(4.func(5))"); @@ -501,7 +505,7 @@ public void memberCallExpr_replaceMiddleBranchArgument() throws Exception { CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 4); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("10.func(20)"); @@ -516,7 +520,7 @@ public void select_replaceField() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder() .setSelect( @@ -542,7 +546,7 @@ public void select_replaceOperand() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), 3); @@ -558,7 +562,7 @@ public void list_replaceElement() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[2, 3, 4]").getAst(); CelAbstractSyntaxTree replacedAst = - 
MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[2, 3, 5]"); @@ -573,7 +577,7 @@ public void createStruct_replaceValue() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("TestAllTypes{single_int64: 2}").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("TestAllTypes{single_int64: 5}"); @@ -588,7 +592,7 @@ public void createMap_replaceKey() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("{5: 1}"); @@ -603,7 +607,7 @@ public void createMap_replaceValue() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("{\"a\": 5}"); @@ -614,7 +618,7 @@ public void comprehension_replaceIterRange() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[true].exists(i, i)").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), 2); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); @@ -627,7 +631,7 @@ public void comprehension_replaceAccuInit() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); CelAbstractSyntaxTree replacedAst = - 
MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 6); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); @@ -643,7 +647,7 @@ public void comprehension_replaceLoopStep() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree( + MUTABLE_AST.replaceSubtree( ast, CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), 5); @@ -656,7 +660,7 @@ public void comprehension_replaceLoopStep() throws Exception { public void mangleComprehensionVariable_singleMacro() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[false].exists(i, i)").getAst(); - CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c"); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -716,7 +720,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception { public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("[x].exists(x, [x].exists(x, x == 1))").getAst(); - CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, "@c"); assertThat(mangledAst.getExpr().toString()) .isEqualTo( @@ -833,7 +837,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw public void mangleComprehensionVariable_hasMacro_noOp() throws Exception { CelAbstractSyntaxTree ast = CEL.compile("has(msg.single_int64)").getAst(); - CelAbstractSyntaxTree mangledAst = MutableAst.mangleComprehensionIdentifierNames(ast, "@c"); + CelAbstractSyntaxTree mangledAst = MUTABLE_AST.mangleComprehensionIdentifierNames(ast, 
"@c"); assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("has(msg.single_int64)"); assertThat( @@ -843,6 +847,21 @@ public void mangleComprehensionVariable_hasMacro_noOp() throws Exception { assertConsistentMacroCalls(ast); } + @Test + public void replaceSubtree_iterationLimitReached_throws() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("true && false").getAst(); + MutableAst mutableAst = MutableAst.newInstance(1); + + IllegalStateException e = + assertThrows( + IllegalStateException.class, + () -> + mutableAst.replaceSubtree( + ast, CelExpr.ofConstantExpr(0, CelConstant.ofValue(false)), 1)); + + assertThat(e).hasMessageThat().isEqualTo("Max iteration count reached."); + } + /** * Asserts that the expressions that appears in source_info's macro calls are consistent with the * actual expr nodes in the AST. diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index eaa8c6ce..7b560adf 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -381,7 +381,7 @@ public void constantFold_astProducesConsistentlyNumberedIds() throws Exception { } @Test - public void maxIterationCountReached_throws() throws Exception { + public void iterationLimitReached_throws() throws Exception { StringBuilder sb = new StringBuilder(); sb.append("0"); for (int i = 1; i < 400; i++) { diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 98e1ff75..b0cd6486 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -634,12 +634,12 @@ 
public void cse_applyConstFoldingAfter() throws Exception { } @Test - public void maxIterationLimitReached_throws() throws Exception { + public void iterationLimitReached_throws() throws Exception { StringBuilder largeExprBuilder = new StringBuilder(); - int maxIterationLimit = 100; - for (int i = 0; i < maxIterationLimit; i++) { + int iterationLimit = 100; + for (int i = 0; i < iterationLimit; i++) { largeExprBuilder.append("[1,2]"); - if (i < maxIterationLimit - 1) { + if (i < iterationLimit - 1) { largeExprBuilder.append("+"); } } @@ -653,7 +653,7 @@ public void maxIterationLimitReached_throws() throws Exception { .addAstOptimizers( SubexpressionOptimizer.newInstance( SubexpressionOptimizerOptions.newBuilder() - .maxIterationLimit(maxIterationLimit) + .iterationLimit(iterationLimit) .build())) .build() .optimize(ast));