diff --git a/common/src/main/java/dev/cel/common/ast/BUILD.bazel b/common/src/main/java/dev/cel/common/ast/BUILD.bazel index c8663e6e..ec79dbed 100644 --- a/common/src/main/java/dev/cel/common/ast/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/ast/BUILD.bazel @@ -88,6 +88,7 @@ java_library( ], deps = [ ":ast", + "//common/annotations", "@maven//:com_google_guava_guava", "@maven//:com_google_protobuf_protobuf_java", ], diff --git a/common/src/main/java/dev/cel/common/ast/CelExpr.java b/common/src/main/java/dev/cel/common/ast/CelExpr.java index b075de5f..c106a1ff 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExpr.java +++ b/common/src/main/java/dev/cel/common/ast/CelExpr.java @@ -216,6 +216,7 @@ public CelComprehension comprehensionOrDefault() { /** Builder for CelExpr. */ @AutoValue.Builder public abstract static class Builder { + public abstract long id(); public abstract Builder setId(long value); @@ -787,6 +788,8 @@ public abstract static class Entry { @AutoValue.Builder public abstract static class Builder { + public abstract long id(); + public abstract CelExpr value(); public abstract Builder setId(long value); @@ -918,6 +921,7 @@ public abstract static class Entry { /** Builder for CelCreateMap.Entry. */ @AutoValue.Builder public abstract static class Builder { + public abstract long id(); public abstract CelExpr key(); diff --git a/common/src/main/java/dev/cel/common/ast/CelExprFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprFactory.java index f73cb056..385e72fe 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprFactory.java @@ -19,16 +19,24 @@ import com.google.common.primitives.UnsignedLong; import com.google.protobuf.ByteString; +import dev.cel.common.annotations.Internal; import java.util.Arrays; /** Factory for generating expression nodes. */ +@Internal public class CelExprFactory { - private final CelExprIdGeneratorFactory.MonotonicIdGenerator idGenerator; + + private final CelExprIdGeneratorFactory.ExprIdGenerator idGenerator; public static CelExprFactory newInstance() { return new CelExprFactory(); } + public static CelExprFactory newInstance( + CelExprIdGeneratorFactory.ExprIdGenerator exprIdGenerator) { + return new CelExprFactory(exprIdGenerator); + } + /** Create a new constant expression. */ public final CelExpr newConstant(CelConstant constant) { return CelExpr.newBuilder().setId(nextExprId()).setConstant(constant).build(); @@ -543,10 +551,19 @@ public final CelExpr newSelect(CelExpr operand, String field, boolean testOnly) /** Returns the next unique expression ID. */ protected long nextExprId() { - return idGenerator.nextExprId(); + return idGenerator.generate( + /* exprId= */ -1); // Unconditionally generate next unique ID (i.e: no renumbering). } protected CelExprFactory() { - idGenerator = CelExprIdGeneratorFactory.newMonotonicIdGenerator(0); + this(CelExprIdGeneratorFactory.newMonotonicIdGenerator(0)); + } + + private CelExprFactory(CelExprIdGeneratorFactory.MonotonicIdGenerator idGenerator) { + this((unused) -> idGenerator.nextExprId()); + } + + private CelExprFactory(CelExprIdGeneratorFactory.ExprIdGenerator exprIdGenerator) { + idGenerator = exprIdGenerator; } } diff --git a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java index 4c8add1f..8ab71c09 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprFormatter.java @@ -14,11 +14,17 @@ package dev.cel.common.ast; +import com.google.common.collect.ImmutableSet; + /** Provides string formatting support for {@link CelExpr}. */ final class CelExprFormatter { private final StringBuilder indent = new StringBuilder(); private final StringBuilder exprBuilder = new StringBuilder(); + /** Denotes a set of expression kinds that will not have a new line inserted. */ + private static final ImmutableSet EXCLUDED_NEWLINE_KINDS = + ImmutableSet.of(CelExpr.ExprKind.Kind.CONSTANT, CelExpr.ExprKind.Kind.NOT_SET); + static String format(CelExpr celExpr) { CelExprFormatter formatter = new CelExprFormatter(); formatter.formatExpr(celExpr); @@ -28,7 +34,7 @@ static String format(CelExpr celExpr) { private void formatExpr(CelExpr celExpr) { append(String.format("%s [%d] {", celExpr.exprKind().getKind(), celExpr.id())); CelExpr.ExprKind.Kind exprKind = celExpr.exprKind().getKind(); - if (!exprKind.equals(CelExpr.ExprKind.Kind.CONSTANT)) { + if (!EXCLUDED_NEWLINE_KINDS.contains(exprKind)) { appendNewline(); } @@ -57,12 +63,17 @@ private void formatExpr(CelExpr celExpr) { case COMPREHENSION: appendComprehension(celExpr.comprehension()); break; + case NOT_SET: + break; default: + // This should be unreachable unless if we've added any other kinds. + indent(); append("Unknown kind: " + exprKind); + outdent(); break; } - if (!exprKind.equals(CelExpr.ExprKind.Kind.CONSTANT)) { + if (!EXCLUDED_NEWLINE_KINDS.contains(exprKind)) { appendNewline(); append("}"); } else { diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java index c434150c..10e259c0 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java @@ -59,8 +59,45 @@ public static class StableIdGenerator { private final HashMap idSet; private long exprId; + /** Checks if the given ID has been encountered before. */ + public boolean hasId(long id) { + return idSet.containsKey(id); + } + + /** + * Generate the next available ID while memoizing the existing ID. + * + *

The main purpose of this is to sanitize a new AST to replace an existing AST's node with. + * The incoming AST may not have its IDs consistently numbered (often, the expr IDs are just + * zeroes). In those cases, we just want to return an incremented expr ID. + * + *

The memoization becomes necessary if the incoming AST contains an expression with macro + * map populated, requiring a normalization pass. In this case, the method behaves largely the + * same as {@link #renumberId}. + * + * @param id Existing ID to memoize. Providing 0 or less will skip the memoization, in which + * case this behaves just like a {@link MonotonicIdGenerator}. + */ + public long nextExprId(long id) { + long nextExprId = ++exprId; + if (id > 0) { + idSet.put(id, nextExprId); + } + return nextExprId; + } + + /** Memoize a given expression ID with a newly generated ID. */ + public void memoize(long existingId, long newId) { + idSet.put(existingId, newId); + } + + /** + * Renumbers the existing expression ID to a newly generated unique ID. The existing ID is + * memoized, and calling this method again with the same ID will always return the same + * generated ID. + */ public long renumberId(long id) { - Preconditions.checkArgument(id >= 0); + Preconditions.checkArgument(id >= 0, "Expr ID must be positive. Got: %s", id); if (id == 0) { return 0; } @@ -81,5 +118,13 @@ private StableIdGenerator(long exprId) { } } + /** Functional interface for generating the next unique expression ID. */ + @FunctionalInterface + public interface ExprIdGenerator { + + /** Generates an expression ID with the provided expr ID as the context. */ + long generate(long exprId); + } + private CelExprIdGeneratorFactory() {} } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java b/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java index f3d77ea9..30e8f5e4 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprFactoryTest.java @@ -16,6 +16,7 @@ import static com.google.common.truth.Truth.assertThat; +import dev.cel.common.ast.CelExprIdGeneratorFactory.StableIdGenerator; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,4 +38,15 @@ public void nextExprId_startingDefaultIsOne() { assertThat(exprFactory.nextExprId()).isEqualTo(1L); assertThat(exprFactory.nextExprId()).isEqualTo(2L); } + + @Test + public void nextExprId_usingStableIdGenerator() { + StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); + CelExprFactory exprFactory = CelExprFactory.newInstance(stableIdGenerator::nextExprId); + + assertThat(exprFactory.nextExprId()).isEqualTo(1L); + assertThat(exprFactory.nextExprId()).isEqualTo(2L); + assertThat(stableIdGenerator.hasId(-1)).isFalse(); + assertThat(stableIdGenerator.hasId(0)).isFalse(); + } } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java index b80779bf..2c863b1f 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprFormatterTest.java @@ -61,6 +61,13 @@ public void constant(@TestParameter ConstantTestCase constantTestCase) throws Ex assertThat(formattedExpr).isEqualTo(constantTestCase.formatted); } + @Test + public void notSet() { + String formattedExpr = CelExprFormatter.format(CelExpr.ofNotSet(1)); + + assertThat(formattedExpr).isEqualTo("NOT_SET [1] {}"); + } + @Test public void select() throws Exception { CelCompiler celCompiler = diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel index e145a208..454ab828 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -81,6 +81,7 @@ java_library( tags = [ ], deps = [ + "//:auto_value", "//common", "//common/annotations", "//common/ast", diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java index 64f5913d..c8bf99c7 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -16,7 +16,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.Math.max; +import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import dev.cel.common.CelAbstractSyntaxTree; @@ -29,8 +31,9 @@ import dev.cel.common.ast.CelExpr.CelCreateMap; import dev.cel.common.ast.CelExpr.CelCreateStruct; import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.ast.CelExprFactory; import dev.cel.common.ast.CelExprIdGeneratorFactory; -import dev.cel.common.ast.CelExprIdGeneratorFactory.MonotonicIdGenerator; +import dev.cel.common.ast.CelExprIdGeneratorFactory.ExprIdGenerator; import dev.cel.common.ast.CelExprIdGeneratorFactory.StableIdGenerator; import dev.cel.common.navigation.CelNavigableExpr; import java.util.Map.Entry; @@ -51,34 +54,172 @@ private MutableAst(ExprIdGenerator celExprIdGenerator, CelExpr.Builder newExpr, this.exprIdToReplace = exprId; } + /** Replaces all the expression IDs in the expression tree with 0. */ + static CelExpr clearExprIds(CelExpr celExpr) { + return renumberExprIds((unused) -> 0, celExpr.toBuilder()).build(); + } + /** - * Replaces a subtree in the given CelExpr. + * Mutates the given AST by replacing a subtree at a given index. * - *

This method should remain package-private. + * @param ast Existing AST being mutated + * @param newExpr New subtree to perform the replacement with. + * @param exprIdToReplace The expr ID in the existing AST to replace the subtree at. */ static CelAbstractSyntaxTree replaceSubtree( CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { - // Update the IDs in the new expression tree first. This ensures that no ID collision - // occurs while attempting to replace the subtree, potentially leading to infinite loop - MonotonicIdGenerator monotonicIdGenerator = - CelExprIdGeneratorFactory.newMonotonicIdGenerator(getMaxId(ast.getExpr())); - CelExpr.Builder newExprBuilder = - renumberExprIds((unused) -> monotonicIdGenerator.nextExprId(), newExpr.toBuilder()); + return replaceSubtree( + ast, + CelAbstractSyntaxTree.newParsedAst(newExpr, CelSource.newBuilder().build()), + exprIdToReplace); + } + /** + * Mutates the given AST by replacing a subtree at a given index. + * + * @param ast Existing AST being mutated + * @param newAst New subtree to perform the replacement with. If the subtree has a macro map + * populated, its macro source is merged with the existing AST's after normalization. + * @param exprIdToReplace The expr ID in the existing AST to replace the subtree at. + */ + static CelAbstractSyntaxTree replaceSubtree( + CelAbstractSyntaxTree ast, CelAbstractSyntaxTree newAst, long exprIdToReplace) { + // Stabilize the incoming AST by renumbering all of its expression IDs. + long maxId = max(getMaxId(ast), getMaxId(newAst)); + newAst = stabilizeAst(newAst, maxId); + + // Mutate the AST root with the new subtree. All the existing expr IDs are renumbered in the + // process, but its original IDs are memoized so that we can normalize the expr IDs + // in the macro source map. StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(0); CelExpr.Builder mutatedRoot = replaceSubtreeImpl( stableIdGenerator::renumberId, ast.getExpr().toBuilder(), - newExprBuilder, + newAst.getExpr().toBuilder(), exprIdToReplace); - // If the source info contained macro call information, their IDs must be normalized. - CelSource normalizedSource = + CelSource newAstSource = ast.getSource(); + if (!newAst.getSource().getMacroCalls().isEmpty()) { + // The root is mutated, but the expr IDs in the macro map needs to be normalized. + // In situations where an AST with a new macro map is being inserted (ex: new bind call), + // the new subtree's expr ID is not memoized in the stable ID generator because the ID never + // existed in the main AST. + // In this case, we forcibly memoize the new subtree ID with a newly generated ID so + // that the macro map IDs can be normalized properly. + stableIdGenerator.memoize( + newAst.getExpr().id(), stableIdGenerator.renumberId(exprIdToReplace)); + newAstSource = combine(newAstSource, newAst.getSource()); + } + + newAstSource = + normalizeMacroSource( + newAstSource, exprIdToReplace, mutatedRoot, stableIdGenerator::renumberId); + + return CelAbstractSyntaxTree.newParsedAst(mutatedRoot.build(), newAstSource); + } + + /** Replaces the subtree at the given ID with a newly created bind macro. */ + static CelAbstractSyntaxTree replaceSubtreeWithNewBindMacro( + CelAbstractSyntaxTree ast, + String varName, + CelExpr varInit, + CelExpr resultExpr, + long exprIdToReplace) { + long maxId = max(getMaxId(varInit), getMaxId(ast)); + StableIdGenerator stableIdGenerator = CelExprIdGeneratorFactory.newStableIdGenerator(maxId); + BindMacro bindMacro = newBindMacro(varName, varInit, resultExpr, stableIdGenerator); + // In situations where the existing AST already contains a macro call (ex: nested cel.binds), + // its macro source must be normalized to make it consistent with the newly generated bind + // macro. + CelSource celSource = normalizeMacroSource( - ast.getSource(), exprIdToReplace, mutatedRoot, stableIdGenerator::renumberId); + ast.getSource(), + -1, // Do not replace any of the subexpr in the macro map. + bindMacro.bindMacro().toBuilder(), + stableIdGenerator::renumberId); + celSource = + celSource.toBuilder() + .addMacroCalls(bindMacro.bindExpr().id(), bindMacro.bindMacro()) + .build(); + + return replaceSubtree( + ast, CelAbstractSyntaxTree.newParsedAst(bindMacro.bindExpr(), celSource), exprIdToReplace); + } + + private static BindMacro newBindMacro( + String varName, CelExpr varInit, CelExpr resultExpr, StableIdGenerator stableIdGenerator) { + // Clear incoming expression IDs in the initialization expression to avoid collision with the + // main AST. + varInit = clearExprIds(varInit); + CelExprFactory exprFactory = + CelExprFactory.newInstance((unused) -> stableIdGenerator.nextExprId(-1)); + CelExpr bindMacroExpr = + exprFactory.fold( + "#unused", + exprFactory.newList(), + varName, + varInit, + exprFactory.newBoolLiteral(false), + exprFactory.newIdentifier(varName), + resultExpr); - return CelAbstractSyntaxTree.newParsedAst(mutatedRoot.build(), normalizedSource); + // Update the IDs in the new expression tree first. This ensures that no ID collision + // occurs while attempting to replace the subtree later, potentially leading to an infinite loop + bindMacroExpr = + renumberExprIds(stableIdGenerator::nextExprId, bindMacroExpr.toBuilder()).build(); + + CelExpr bindMacroCallExpr = + exprFactory + .newReceiverCall( + "bind", + CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), "cel"), + CelExpr.ofIdentExpr(stableIdGenerator.nextExprId(-1), varName), + bindMacroExpr.comprehension().accuInit(), + bindMacroExpr.comprehension().result()) + .toBuilder() + .setId(0) + .build(); + + return BindMacro.of(bindMacroExpr, bindMacroCallExpr); + } + + private static CelSource combine(CelSource celSource1, CelSource celSource2) { + ImmutableMap.Builder macroMap = ImmutableMap.builder(); + macroMap.putAll(celSource1.getMacroCalls()); + macroMap.putAll(celSource2.getMacroCalls()); + + return CelSource.newBuilder().addAllMacroCalls(macroMap.buildOrThrow()).build(); + } + + /** + * Stabilizes the incoming AST by ensuring that all of expr IDs are consistently renumbered + * (monotonically increased) from the starting seed ID. If the AST contains any macro calls, its + * IDs are also normalized. + */ + private static CelAbstractSyntaxTree stabilizeAst(CelAbstractSyntaxTree ast, long seedExprId) { + StableIdGenerator stableIdGenerator = + CelExprIdGeneratorFactory.newStableIdGenerator(seedExprId); + CelExpr.Builder newExprBuilder = + renumberExprIds(stableIdGenerator::nextExprId, ast.getExpr().toBuilder()); + + if (ast.getSource().getMacroCalls().isEmpty()) { + return CelAbstractSyntaxTree.newParsedAst(newExprBuilder.build(), ast.getSource()); + } + + CelSource.Builder sourceBuilder = CelSource.newBuilder(); + // Update the macro call IDs and their call IDs + for (Entry macroCall : ast.getSource().getMacroCalls().entrySet()) { + long macroId = macroCall.getKey(); + long newCallId = stableIdGenerator.renumberId(macroId); + + CelExpr.Builder newCall = + renumberExprIds(stableIdGenerator::renumberId, macroCall.getValue().toBuilder()); + + sourceBuilder.addMacroCalls(newCallId, newCall.build()); + } + + return CelAbstractSyntaxTree.newParsedAst(newExprBuilder.build(), sourceBuilder.build()); } private static CelSource normalizeMacroSource( @@ -116,6 +257,10 @@ private static CelSource normalizeMacroSource( long macroId = macroCall.getKey(); long callId = idGenerator.generate(macroId); + if (!allExprs.containsKey(callId)) { + continue; + } + CelExpr.Builder newCall = renumberExprIds(idGenerator, macroCall.getValue().toBuilder()); CelNavigableExpr callNav = CelNavigableExpr.fromExpr(newCall.build()); ImmutableList callDescendants = @@ -152,6 +297,15 @@ private static CelExpr.Builder renumberExprIds( return mutableAst.visit(root); } + private static long getMaxId(CelAbstractSyntaxTree ast) { + long maxId = getMaxId(ast.getExpr()); + for (Entry macroCall : ast.getSource().getMacroCalls().entrySet()) { + maxId = max(maxId, getMaxId(macroCall.getValue())); + } + + return maxId; + } + private static long getMaxId(CelExpr newExpr) { return CelNavigableExpr.fromExpr(newExpr) .allNodes() @@ -216,6 +370,7 @@ private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateStruct.Builder crea ImmutableList entries = createStruct.getEntriesBuilders(); for (int i = 0; i < entries.size(); i++) { CelCreateStruct.Entry.Builder entry = entries.get(i); + entry.setId(celExprIdGenerator.generate(entry.id())); entry.setValue(visit(entry.value().toBuilder()).build()); createStruct.setEntry(i, entry.build()); @@ -228,6 +383,7 @@ private CelExpr.Builder visit(CelExpr.Builder expr, CelCreateMap.Builder createM ImmutableList entriesBuilders = createMap.getEntriesBuilders(); for (int i = 0; i < entriesBuilders.size(); i++) { CelCreateMap.Entry.Builder entry = entriesBuilders.get(i); + entry.setId(celExprIdGenerator.generate(entry.id())); entry.setKey(visit(entry.key().toBuilder()).build()); entry.setValue(visit(entry.value().toBuilder()).build()); @@ -257,10 +413,24 @@ private CelExpr.Builder visit(CelExpr.Builder expr, CelComprehension.Builder com return expr.setComprehension(comprehension.build()); } - @FunctionalInterface - private interface ExprIdGenerator { + /** + * Intermediate value class to store the generated CelExpr for the bind macro and the macro call + * information. + */ + @AutoValue + abstract static class BindMacro { + + /** Comprehension expr for the generated cel.bind macro. */ + abstract CelExpr bindExpr(); - /** Generates an expression ID based on the provided ID. */ - long generate(long exprId); + /** + * Call expr representation that will be stored in the macro call map of the AST. This is + * typically used for the purposes of supporting unparse. + */ + abstract CelExpr bindMacro(); + + private static BindMacro of(CelExpr bindExpr, CelExpr bindMacro) { + return new AutoValue_MutableAst_BindMacro(bindExpr, bindMacro); + } } } diff --git a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel index a769bc20..f4090313 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel +++ b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel @@ -17,6 +17,7 @@ java_library( "//common/resources/testdata/proto3:test_all_types_java_proto", "//common/types", "//compiler", + "//extensions", "//extensions:optional_library", "//optimizer", "//optimizer:mutable_ast", @@ -25,6 +26,7 @@ java_library( "//optimizer:optimizer_impl", "//parser", "//parser:macro", + "//parser:operator", "//parser:unparser", "//runtime", "@maven//:com_google_guava_guava", diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java index 65d94b70..b7f1d29d 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -19,6 +19,7 @@ import com.google.common.collect.ImmutableMap; import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; @@ -27,16 +28,20 @@ import dev.cel.common.CelOverloadDecl; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelIdent; import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.ast.CelExpr.ExprKind.Kind; import dev.cel.common.navigation.CelNavigableAst; import dev.cel.common.navigation.CelNavigableExpr; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; +import dev.cel.extensions.CelExtensions; import dev.cel.extensions.CelOptionalLibrary; import dev.cel.parser.CelStandardMacro; import dev.cel.parser.CelUnparser; import dev.cel.parser.CelUnparserFactory; +import dev.cel.parser.Operator; import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; import org.junit.Test; import org.junit.runner.RunWith; @@ -48,7 +53,7 @@ public class MutableAstTest { .setStandardMacros(CelStandardMacro.STANDARD_MACROS) .setOptions(CelOptions.current().populateMacroCalls(true).build()) .addMessageTypes(TestAllTypes.getDescriptor()) - .addCompilerLibraries(CelOptionalLibrary.INSTANCE) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) .setContainer("dev.cel.testing.testdata.proto3") .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) @@ -109,6 +114,248 @@ public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception { assertThat(mutatedAst.getSource().getMacroCalls()).isNotEmpty(); } + @Test + @TestParameters("{source: '[1].exists(x, x > 0)', expectedMacroCallSize: 1}") + @TestParameters( + "{source: '[1].exists(x, x > 0) && [2].exists(x, x > 0)', expectedMacroCallSize: 2}") + @TestParameters( + "{source: '[1].exists(x, [2].exists(y, x > 0 && y > x))', expectedMacroCallSize: 2}") + public void replaceSubtree_rootReplacedWithMacro_macroCallPopulated( + String source, int expectedMacroCallSize) throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("1").getAst(); + CelAbstractSyntaxTree ast2 = CEL.compile(source).getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().expr().id()); + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(expectedMacroCallSize); + assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo(source); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(true); + } + + @Test + public void replaceSubtree_branchReplacedWithMacro_macroCallPopulated() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("true && false").getAst(); + CelAbstractSyntaxTree ast2 = CEL.compile("[1].exists(x, x > 0)").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree(ast, ast2, 3); // Replace false with the macro expr + CelAbstractSyntaxTree mutatedAst2 = + MutableAst.replaceSubtree(ast, ast2, 1); // Replace true with the macro expr + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(1); + assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("true && [1].exists(x, x > 0)"); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(true); + assertThat(mutatedAst2.getSource().getMacroCalls()).hasSize(1); + assertThat(CEL_UNPARSER.unparse(mutatedAst2)).isEqualTo("[1].exists(x, x > 0) && false"); + assertThat(CEL.createProgram(CEL.check(mutatedAst2).getAst()).eval()).isEqualTo(false); + } + + @Test + public void replaceSubtree_macroInsertedIntoExistingMacro_macroCallPopulated() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("[1].exists(x, x > 0 && true)").getAst(); + CelAbstractSyntaxTree ast2 = CEL.compile("[2].exists(y, y > 0)").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree(ast, ast2, 9); // Replace true with the ast2 maro expr + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(2); + assertThat(CEL_UNPARSER.unparse(mutatedAst)) + .isEqualTo("[1].exists(x, x > 0 && [2].exists(y, y > 0))"); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(true); + } + + @Test + public void replaceSubtreeWithNewBindMacro_replaceRoot() throws Exception { + CelAbstractSyntaxTree ast = CEL.compile("1 + 1").getAst(); + String variableName = "@r0"; + CelExpr resultExpr = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.ofIdentExpr(0, variableName), CelExpr.ofIdentExpr(0, variableName)) + .build()) + .build(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtreeWithNewBindMacro( + ast, + variableName, + CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), + resultExpr, + CelNavigableAst.fromAst(ast).getRoot().expr().id()); + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(1); + assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("cel.bind(@r0, 3, @r0 + @r0)"); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(6); + assertConsistentMacroCalls(mutatedAst); + } + + @Test + public void replaceSubtreeWithNewBindMacro_nestedBindMacro_replaceComprehensionResult() + throws Exception { + // Arrange + CelAbstractSyntaxTree ast = CEL.compile("1 + 1").getAst(); + String variableName = "@r0"; + CelExpr resultExpr = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.ofIdentExpr(0, variableName), CelExpr.ofIdentExpr(0, variableName)) + .build()) + .build(); + + // Act + // Perform the initial replacement. (1 + 1) -> cel.bind(@r0, 3, @r0 + @r0) + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtreeWithNewBindMacro( + ast, + variableName, + CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), + resultExpr, + 2); // Replace + + String nestedVariableName = "@r1"; + // Construct a new result expression of the form @r0 + @r0 + @r1 + @r1 + resultExpr = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.ofIdentExpr(0, variableName), + CelExpr.ofIdentExpr(0, variableName)) + .build()) + .build(), + CelExpr.ofIdentExpr(0, nestedVariableName)) + .build()) + .build(), + CelExpr.ofIdentExpr(0, nestedVariableName)) + .build()) + .build(); + // Find the call node (_+_) in the comprehension's result + long exprIdToReplace = + CelNavigableAst.fromAst(mutatedAst) + .getRoot() + .children() + .filter( + node -> + node.getKind().equals(Kind.CALL) + && node.parent().get().getKind().equals(Kind.COMPREHENSION)) + .findAny() + .get() + .expr() + .id(); + // This should produce cel.bind(@r1, 1, cel.bind(@r0, 3, @r0 + @r0 + @r1 + @r1)) + mutatedAst = + MutableAst.replaceSubtreeWithNewBindMacro( + mutatedAst, + nestedVariableName, + CelExpr.ofConstantExpr(0, CelConstant.ofValue(1L)), + resultExpr, + exprIdToReplace); // Replace + + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(2); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(8); + assertThat(CEL_UNPARSER.unparse(mutatedAst)) + .isEqualTo("cel.bind(@r0, 3, cel.bind(@r1, 1, @r0 + @r0 + @r1 + @r1))"); + assertConsistentMacroCalls(mutatedAst); + } + + @Test + public void replaceSubtreeWithNewBindMacro_replaceRootWithNestedBindMacro() throws Exception { + // Arrange + CelAbstractSyntaxTree ast = CEL.compile("1 + 1 + 3 + 3").getAst(); + String variableName = "@r0"; + CelExpr resultExpr = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.ofIdentExpr(0, variableName), CelExpr.ofIdentExpr(0, variableName)) + .build()) + .build(); + + // Act + // Perform the initial replacement. (1 + 1 + 3 + 3) -> cel.bind(@r0, 1, @r0 + @r0) + 3 + 3 + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtreeWithNewBindMacro( + ast, + variableName, + CelExpr.ofConstantExpr(0, CelConstant.ofValue(1L)), + resultExpr, + 2); // Replace + + // Construct a new result expression of the form: + // cel.bind(@r1, 3, cel.bind(@r0, 1, @r0 + @r0) + @r1 + @r1) + String nestedVariableName = "@r1"; + CelExpr bindMacro = + CelNavigableAst.fromAst(mutatedAst) + .getRoot() + .descendants() + .filter(node -> node.getKind().equals(Kind.COMPREHENSION)) + .findAny() + .get() + .expr(); + resultExpr = + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs( + CelExpr.newBuilder() + .setCall( + CelCall.newBuilder() + .setFunction(Operator.ADD.getFunction()) + .addArgs(bindMacro, CelExpr.ofIdentExpr(0, nestedVariableName)) + .build()) + .build(), + CelExpr.ofIdentExpr(0, nestedVariableName)) + .build()) + .build(); + // Replace the root with the new result and a bind macro inserted + mutatedAst = + MutableAst.replaceSubtreeWithNewBindMacro( + mutatedAst, + nestedVariableName, + CelExpr.ofConstantExpr(0, CelConstant.ofValue(3L)), + resultExpr, + 1); + + assertThat(mutatedAst.getSource().getMacroCalls()).hasSize(2); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(8); + assertThat(CEL_UNPARSER.unparse(mutatedAst)) + .isEqualTo("cel.bind(@r1, 3, cel.bind(@r0, 1, @r0 + @r0) + @r1 + @r1)"); + assertConsistentMacroCalls(mutatedAst); + } + + @Test + public void replaceSubtree_macroReplacedWithConstExpr_macroCallCleared() throws Exception { + CelAbstractSyntaxTree ast = + CEL.compile("[1].exists(x, x > 0) && [2].exists(x, x > 0)").getAst(); + CelAbstractSyntaxTree ast2 = CEL.compile("1").getAst(); + + CelAbstractSyntaxTree mutatedAst = + MutableAst.replaceSubtree(ast, ast2, CelNavigableAst.fromAst(ast).getRoot().expr().id()); + + assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty(); + assertThat(CEL_UNPARSER.unparse(mutatedAst)).isEqualTo("1"); + assertThat(CEL.createProgram(CEL.check(mutatedAst).getAst()).eval()).isEqualTo(1); + } + @Test public void globalCallExpr_replaceRoot() throws Exception { // Tree shape (brackets are expr IDs): @@ -371,8 +618,8 @@ public void comprehension_replaceIterRange() throws Exception { ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(false)).build(), 2); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); - assertConsistentMacroCalls(ast); assertThat(CEL.createProgram(CEL.check(replacedAst).getAst()).eval()).isEqualTo(false); + assertConsistentMacroCalls(ast); } @Test @@ -384,8 +631,11 @@ public void comprehension_replaceAccuInit() throws Exception { ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 6); assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[false].exists(i, i)"); - assertConsistentMacroCalls(ast); assertThat(CEL.createProgram(CEL.check(replacedAst).getAst()).eval()).isEqualTo(true); + // Check that the init value of accumulator has actually been replaced. + assertThat(ast.getExpr().comprehension().accuInit().constant().booleanValue()).isFalse(); + assertThat(replacedAst.getExpr().comprehension().accuInit().constant().booleanValue()).isTrue(); + assertConsistentMacroCalls(ast); } @Test @@ -402,18 +652,6 @@ public void comprehension_replaceLoopStep() throws Exception { assertConsistentMacroCalls(ast); } - @Test - public void comprehension_astContainsDuplicateNodes() throws Exception { - CelAbstractSyntaxTree ast = CEL.compile("[{\"a\": 1}].map(i, i)").getAst(); - - // AST contains two duplicate expr (ID: 9). Just ensure that it doesn't throw. - CelAbstractSyntaxTree replacedAst = - MutableAst.replaceSubtree(ast, CelExpr.newBuilder().build(), -1); - - assertThat(CEL_UNPARSER.unparse(replacedAst)).isEqualTo("[{\"a\": 1}].map(i, i)"); - assertConsistentMacroCalls(ast); - } - /** * Asserts that the expressions that appears in source_info's macro calls are consistent with the * actual expr nodes in the AST.