diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java index 59cde749..cc49ffce 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java @@ -95,10 +95,9 @@ private int visit(CelNavigableExpr.Builder navigableExpr) { if (navigableExpr.depth() > MAX_DESCENDANTS_RECURSION_DEPTH - 1) { throw new IllegalStateException("Max recursion depth reached."); } - if (navigableExpr.depth() > maxDepth) { - return -1; - } - if (traversalOrder.equals(TraversalOrder.PRE_ORDER)) { + + boolean addToStream = navigableExpr.depth() <= maxDepth; + if (addToStream && traversalOrder.equals(TraversalOrder.PRE_ORDER)) { streamBuilder.add(navigableExpr); } @@ -129,7 +128,7 @@ private int visit(CelNavigableExpr.Builder navigableExpr) { } navigableExpr.setHeight(height); - if (traversalOrder.equals(TraversalOrder.POST_ORDER)) { + if (addToStream && traversalOrder.equals(TraversalOrder.POST_ORDER)) { streamBuilder.add(navigableExpr); } diff --git a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java index 3501aea6..e263020f 100644 --- a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java +++ b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.UnsignedLong; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.common.CelAbstractSyntaxTree; @@ -147,6 +148,51 @@ public void add_postOrder_heightSet() throws Exception { assertThat(allNodeHeights).containsExactly(0, 0, 1, 0, 2).inOrder(); // 1, a, +, 2, + } + @Test + public void add_fromLeaf_heightSetForParents() throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // 1 a + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + CelNavigableExpr oneConst = + navigableAst + .getRoot() + .descendants() + .filter(node -> node.expr().constantOrDefault().int64Value() == 1) + .findAny() + .get(); + assertThat(oneConst.height()).isEqualTo(0); // 1 + assertThat(oneConst.parent().get().height()).isEqualTo(1); // + + assertThat(oneConst.parent().get().parent().get().height()).isEqualTo(2); // root + } + + @Test + public void add_children_heightSet(@TestParameter TraversalOrder traversalOrder) + throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + // Tree shape: + // + + // + 2 + // + a + // 3 + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2 + 3").getAst(); + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodeHeights = + navigableAst + .getRoot() + .children(traversalOrder) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + assertThat(allNodeHeights).containsExactly(2, 0).inOrder(); // + (2), 2 (0) regardless of order + } + @Test public void add_filterConstants_allNodesReturned() throws Exception { CelCompiler compiler = @@ -1022,6 +1068,24 @@ public void callExpr_postOrder_heightSet() throws Exception { assertThat(allNodes).containsExactly(0, 0, 1, 0, 2, 0, 3, 0, 0, 1, 0, 2, 0, 4).inOrder(); } + @Test + public void createList_children_heightSet(@TestParameter TraversalOrder traversalOrder) + throws Exception { + CelCompiler compiler = + CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); + CelAbstractSyntaxTree ast = compiler.compile("[1, a, (2 + 2), (3 + 4 + 5)]").getAst(); + + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + + ImmutableList allNodeHeights = + navigableAst + .getRoot() + .children(traversalOrder) + .map(CelNavigableExpr::height) + .collect(toImmutableList()); + assertThat(allNodeHeights).containsExactly(0, 0, 1, 2).inOrder(); + } + @Test public void maxRecursionLimitReached_throws() throws Exception { StringBuilder sb = new StringBuilder();