From f98b582ab7bb96dd468f452d12dd59f68ccc06fa Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 13 Feb 2024 09:56:44 -0800 Subject: [PATCH] Add height property to CelNavigableExpr PiperOrigin-RevId: 606657258 --- .../dev/cel/common/navigation/BUILD.bazel | 1 + .../common/navigation/CelNavigableExpr.java | 45 +++- .../navigation/CelNavigableExprVisitor.java | 101 +++++---- .../CelNavigableExprVisitorTest.java | 206 ++++++++++++++++++ 4 files changed, 308 insertions(+), 45 deletions(-) diff --git a/common/src/main/java/dev/cel/common/navigation/BUILD.bazel b/common/src/main/java/dev/cel/common/navigation/BUILD.bazel index a9ca1c4e..07604c5f 100644 --- a/common/src/main/java/dev/cel/common/navigation/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/navigation/BUILD.bazel @@ -20,6 +20,7 @@ java_library( "//:auto_value", "//common", "//common/ast", + "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], ) diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java index 326f8b78..249a3126 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java @@ -15,6 +15,8 @@ package dev.cel.common.navigation; import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.ExprKind; @@ -59,6 +61,12 @@ public long id() { /** Represents the count of transitive parents. Depth of an AST's root is 0. */ public abstract int depth(); + /** + * Represents the maximum count of children from any of its branches. Height of a leaf node is 0. + * For example, the height of the call node 'func' in expression `(1 + 2 + 3).func(4 + 5)` is 3. + */ + public abstract int height(); + /** Constructs a new instance of {@link CelNavigableExpr} from {@link CelExpr}. */ public static CelNavigableExpr fromExpr(CelExpr expr) { return CelNavigableExpr.builder().setExpr(expr).build(); @@ -93,7 +101,8 @@ public Stream descendants() { * the specified traversal order. */ public Stream descendants(TraversalOrder traversalOrder) { - return CelNavigableExprVisitor.collect(this, traversalOrder).filter(node -> !node.equals(this)); + return CelNavigableExprVisitor.collect(this, traversalOrder) + .filter(node -> node.depth() > this.depth()); } /** @@ -110,7 +119,7 @@ public Stream children() { */ public Stream children(TraversalOrder traversalOrder) { return CelNavigableExprVisitor.collect(this, this.depth() + 1, traversalOrder) - .filter(node -> !node.equals(this)); + .filter(node -> node.depth() > this.depth()); } /** Returns the underlying kind of the {@link CelExpr}. */ @@ -120,21 +129,47 @@ public ExprKind.Kind getKind() { /** Create a new builder to construct a {@link CelNavigableExpr} instance. */ public static Builder builder() { - return new AutoValue_CelNavigableExpr.Builder().setDepth(0); + return new AutoValue_CelNavigableExpr.Builder().setDepth(0).setHeight(0); } /** Builder to configure {@link CelNavigableExpr}. */ @AutoValue.Builder public abstract static class Builder { + private Builder parentBuilder; + public abstract CelExpr expr(); + public abstract int depth(); + + public ExprKind.Kind getKind() { + return expr().exprKind().getKind(); + } + public abstract Builder setExpr(CelExpr value); - public abstract Builder setParent(CelNavigableExpr value); + abstract Builder setParent(CelNavigableExpr value); + + @CanIgnoreReturnValue + public Builder setParentBuilder(CelNavigableExpr.Builder value) { + parentBuilder = value; + return this; + } public abstract Builder setDepth(int value); - public abstract CelNavigableExpr build(); + public abstract Builder setHeight(int value); + + public abstract CelNavigableExpr autoBuild(); + + @CheckReturnValue + public CelNavigableExpr build() { + if (parentBuilder != null) { + setParent(parentBuilder.build()); + } + return autoBuild(); + } } + + public abstract Builder toBuilder(); } diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java index 8abcf017..59cde749 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java @@ -14,6 +14,8 @@ package dev.cel.common.navigation; +import static java.lang.Math.max; + import com.google.common.collect.ImmutableList; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; @@ -29,7 +31,7 @@ final class CelNavigableExprVisitor { private static final int MAX_DESCENDANTS_RECURSION_DEPTH = 500; - private final Stream.Builder streamBuilder; + private final Stream.Builder streamBuilder; private final TraversalOrder traversalOrder; private final int maxDepth; @@ -84,105 +86,124 @@ static Stream collect( CelNavigableExpr navigableExpr, int maxDepth, TraversalOrder traversalOrder) { CelNavigableExprVisitor visitor = new CelNavigableExprVisitor(maxDepth, traversalOrder); - visitor.visit(navigableExpr); + visitor.visit(navigableExpr.toBuilder()); - return visitor.streamBuilder.build(); + return visitor.streamBuilder.build().map(CelNavigableExpr.Builder::build); } - private void visit(CelNavigableExpr navigableExpr) { + private int visit(CelNavigableExpr.Builder navigableExpr) { if (navigableExpr.depth() > MAX_DESCENDANTS_RECURSION_DEPTH - 1) { throw new IllegalStateException("Max recursion depth reached."); } if (navigableExpr.depth() > maxDepth) { - return; + return -1; } if (traversalOrder.equals(TraversalOrder.PRE_ORDER)) { streamBuilder.add(navigableExpr); } + int height = 1; switch (navigableExpr.getKind()) { case CALL: - visit(navigableExpr, navigableExpr.expr().call()); + height += visit(navigableExpr, navigableExpr.expr().call()); break; case CREATE_LIST: - visit(navigableExpr, navigableExpr.expr().createList()); + height += visit(navigableExpr, navigableExpr.expr().createList()); break; case SELECT: - visit(navigableExpr, navigableExpr.expr().select()); + height += visit(navigableExpr, navigableExpr.expr().select()); break; case CREATE_STRUCT: - visitStruct(navigableExpr, navigableExpr.expr().createStruct()); + height += visitStruct(navigableExpr, navigableExpr.expr().createStruct()); break; case CREATE_MAP: - visitMap(navigableExpr, navigableExpr.expr().createMap()); + height += visitMap(navigableExpr, navigableExpr.expr().createMap()); break; case COMPREHENSION: - visit(navigableExpr, navigableExpr.expr().comprehension()); + height += visit(navigableExpr, navigableExpr.expr().comprehension()); break; default: + // This is a leaf node + height = 0; break; } + navigableExpr.setHeight(height); if (traversalOrder.equals(TraversalOrder.POST_ORDER)) { streamBuilder.add(navigableExpr); } + + return height; } - private void visit(CelNavigableExpr navigableExpr, CelCall call) { + private int visit(CelNavigableExpr.Builder navigableExpr, CelCall call) { + int targetHeight = 0; if (call.target().isPresent()) { - CelNavigableExpr target = newNavigableChild(navigableExpr, call.target().get()); - visit(target); + CelNavigableExpr.Builder target = newNavigableChild(navigableExpr, call.target().get()); + targetHeight = visit(target); } - visitExprList(call.args(), navigableExpr); + int argumentHeight = visitExprList(call.args(), navigableExpr); + return max(targetHeight, argumentHeight); } - private void visit(CelNavigableExpr navigableExpr, CelCreateList createList) { - visitExprList(createList.elements(), navigableExpr); + private int visit(CelNavigableExpr.Builder navigableExpr, CelCreateList createList) { + return visitExprList(createList.elements(), navigableExpr); } - private void visit(CelNavigableExpr navigableExpr, CelSelect selectExpr) { - CelNavigableExpr operand = newNavigableChild(navigableExpr, selectExpr.operand()); - visit(operand); + private int visit(CelNavigableExpr.Builder navigableExpr, CelSelect selectExpr) { + CelNavigableExpr.Builder operand = newNavigableChild(navigableExpr, selectExpr.operand()); + return visit(operand); } - private void visit(CelNavigableExpr navigableExpr, CelComprehension comprehension) { - visit(newNavigableChild(navigableExpr, comprehension.iterRange())); - visit(newNavigableChild(navigableExpr, comprehension.accuInit())); - visit(newNavigableChild(navigableExpr, comprehension.loopCondition())); - visit(newNavigableChild(navigableExpr, comprehension.loopStep())); - visit(newNavigableChild(navigableExpr, comprehension.result())); + private int visit(CelNavigableExpr.Builder navigableExpr, CelComprehension comprehension) { + int maxHeight = 0; + maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.iterRange())), maxHeight); + maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.accuInit())), maxHeight); + maxHeight = + max(visit(newNavigableChild(navigableExpr, comprehension.loopCondition())), maxHeight); + maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.loopStep())), maxHeight); + maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.result())), maxHeight); + + return maxHeight; } - private void visitStruct(CelNavigableExpr navigableExpr, CelCreateStruct struct) { + private int visitStruct(CelNavigableExpr.Builder navigableExpr, CelCreateStruct struct) { + int maxHeight = 0; for (CelCreateStruct.Entry entry : struct.entries()) { - CelNavigableExpr value = newNavigableChild(navigableExpr, entry.value()); - visit(value); + CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value()); + maxHeight = max(visit(value), maxHeight); } + return maxHeight; } - private void visitMap(CelNavigableExpr navigableExpr, CelCreateMap map) { + private int visitMap(CelNavigableExpr.Builder navigableExpr, CelCreateMap map) { + int maxHeight = 0; for (CelCreateMap.Entry entry : map.entries()) { - CelNavigableExpr key = newNavigableChild(navigableExpr, entry.key()); - visit(key); + CelNavigableExpr.Builder key = newNavigableChild(navigableExpr, entry.key()); + maxHeight = max(visit(key), maxHeight); - CelNavigableExpr value = newNavigableChild(navigableExpr, entry.value()); - visit(value); + CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value()); + maxHeight = max(visit(value), maxHeight); } + return 0; } - private void visitExprList(ImmutableList createListExpr, CelNavigableExpr parent) { + private int visitExprList( + ImmutableList createListExpr, CelNavigableExpr.Builder parent) { + int maxHeight = 0; for (CelExpr expr : createListExpr) { - CelNavigableExpr arg = newNavigableChild(parent, expr); - visit(arg); + CelNavigableExpr.Builder arg = newNavigableChild(parent, expr); + maxHeight = max(visit(arg), maxHeight); } + return maxHeight; } - private CelNavigableExpr newNavigableChild(CelNavigableExpr parent, CelExpr expr) { + private CelNavigableExpr.Builder newNavigableChild( + CelNavigableExpr.Builder parent, CelExpr expr) { return CelNavigableExpr.builder() .setExpr(expr) .setDepth(parent.depth() + 1) - .setParent(parent) - .build(); + .setParentBuilder(parent); } } diff --git a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java index f3739ca5..3501aea6 100644 --- a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java +++ b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java @@ -107,6 +107,46 @@ public void add_allNodes_allNodesReturned() throws Exception { CelExpr.ofConstantExpr(5, CelConstant.ofValue(2))); } + @Test + public void add_preOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // 1 a + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodeHeights = + navigableAst + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + assertThat(allNodeHeights).containsExactly(2, 1, 0, 0, 0).inOrder(); // +, +, 1, a, 2 + } + + @Test + public void add_postOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // 1 a + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodeHeights = + navigableAst + .getRoot() + .allNodes(TraversalOrder.POST_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + assertThat(allNodeHeights).containsExactly(0, 0, 1, 0, 2).inOrder(); // 1, a, +, 2, + + } + @Test public void add_filterConstants_allNodesReturned() throws Exception { CelCompiler compiler = @@ -450,6 +490,46 @@ public void messageConstruction_filterCreateStruct_allNodesReturned() throws Exc false)))); } + @Test + public void messageConstruction_preOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .build(); + CelAbstractSyntaxTree ast = compiler.compile("TestAllTypes{single_int64: 1}").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(1, 0).inOrder(); + } + + @Test + public void messageConstruction_postOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .build(); + CelAbstractSyntaxTree ast = compiler.compile("TestAllTypes{single_int64: 1}").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.POST_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(0, 1).inOrder(); + } + @Test public void mapConstruction_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); @@ -496,6 +576,38 @@ public void mapConstruction_filterCreateMap_allNodesReturned() throws Exception CelExpr.ofCreateMapEntryExpr(2, mapKeyExpr, mapValueExpr, false)))); } + @Test + public void mapConstruction_preOrder_heightSet() throws Exception { + CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); + CelAbstractSyntaxTree ast = compiler.compile("{'key': 2}").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(1, 0, 0).inOrder(); + } + + @Test + public void mapConstruction_postOrder_heightSet() throws Exception { + CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); + CelAbstractSyntaxTree ast = compiler.compile("{'key': 2}").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.POST_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(0, 0, 1).inOrder(); + } + @Test public void emptyMapConstruction_allNodesReturned() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().build(); @@ -633,6 +745,44 @@ public void comprehension_postOrder_allNodesReturned() throws Exception { .inOrder(); } + @Test + public void comprehension_preOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.EXISTS) + .build(); + CelAbstractSyntaxTree ast = compiler.compile("[true].exists(i, i)").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(3, 1, 0, 0, 2, 1, 0, 1, 0, 0, 0).inOrder(); + } + + @Test + public void comprehension_postOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .setStandardMacros(CelStandardMacro.EXISTS) + .build(); + CelAbstractSyntaxTree ast = compiler.compile("[true].exists(i, i)").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.POST_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(0, 1, 0, 0, 1, 2, 0, 0, 1, 0, 3).inOrder(); + } + @Test public void comprehension_allNodes_parentsPopulated() throws Exception { CelCompiler compiler = @@ -816,6 +966,62 @@ public void callExpr_postOrder() throws Exception { .inOrder(); } + @Test + public void callExpr_preOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addFunctionDeclarations( + newFunctionDeclaration( + "test", + newMemberOverload( + "test_overload", + SimpleType.STRING, + SimpleType.STRING, + SimpleType.INT, + SimpleType.UINT))) + .build(); + CelAbstractSyntaxTree ast = + compiler.compile("('a' + 'b' + 'c' + 'd').test((1 + 2 + 3), 6u)").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.PRE_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(4, 3, 2, 1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0).inOrder(); + } + + @Test + public void callExpr_postOrder_heightSet() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder() + .addFunctionDeclarations( + newFunctionDeclaration( + "test", + newMemberOverload( + "test_overload", + SimpleType.STRING, + SimpleType.STRING, + SimpleType.INT, + SimpleType.UINT))) + .build(); + CelAbstractSyntaxTree ast = + compiler.compile("('a' + 'b' + 'c' + 'd').test((1 + 2 + 3), 6u)").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodes = + navigableAst + .getRoot() + .allNodes(TraversalOrder.POST_ORDER) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + + assertThat(allNodes).containsExactly(0, 0, 1, 0, 2, 0, 3, 0, 0, 1, 0, 2, 0, 4).inOrder(); + } + @Test public void maxRecursionLimitReached_throws() throws Exception { StringBuilder sb = new StringBuilder();