diff --git a/common/src/main/java/dev/cel/common/navigation/BUILD.bazel b/common/src/main/java/dev/cel/common/navigation/BUILD.bazel index 07604c5f..5a3281e7 100644 --- a/common/src/main/java/dev/cel/common/navigation/BUILD.bazel +++ b/common/src/main/java/dev/cel/common/navigation/BUILD.bazel @@ -13,6 +13,7 @@ java_library( "CelNavigableAst.java", "CelNavigableExpr.java", "CelNavigableExprVisitor.java", + "ExprHeightCalculator.java", ], tags = [ ], diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java index 249a3126..d0cd2ba7 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExpr.java @@ -15,8 +15,6 @@ package dev.cel.common.navigation; import com.google.auto.value.AutoValue; -import com.google.errorprone.annotations.CanIgnoreReturnValue; -import com.google.errorprone.annotations.CheckReturnValue; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelComprehension; import dev.cel.common.ast.CelExpr.ExprKind; @@ -69,7 +67,12 @@ public long id() { /** Constructs a new instance of {@link CelNavigableExpr} from {@link CelExpr}. */ public static CelNavigableExpr fromExpr(CelExpr expr) { - return CelNavigableExpr.builder().setExpr(expr).build(); + ExprHeightCalculator exprHeightCalculator = new ExprHeightCalculator(expr); + + return CelNavigableExpr.builder() + .setExpr(expr) + .setHeight(exprHeightCalculator.getHeight(expr.id())) + .build(); } /** @@ -136,8 +139,6 @@ public static Builder builder() { @AutoValue.Builder public abstract static class Builder { - private Builder parentBuilder; - public abstract CelExpr expr(); public abstract int depth(); @@ -150,25 +151,11 @@ public ExprKind.Kind getKind() { abstract Builder setParent(CelNavigableExpr value); - @CanIgnoreReturnValue - public Builder setParentBuilder(CelNavigableExpr.Builder value) { - parentBuilder = value; - return this; - } - public abstract Builder setDepth(int value); public abstract Builder setHeight(int value); - public abstract CelNavigableExpr autoBuild(); - - @CheckReturnValue - public CelNavigableExpr build() { - if (parentBuilder != null) { - setParent(parentBuilder.build()); - } - return autoBuild(); - } + public abstract CelNavigableExpr build(); } public abstract Builder toBuilder(); diff --git a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java index cc49ffce..afafd89c 100644 --- a/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java +++ b/common/src/main/java/dev/cel/common/navigation/CelNavigableExprVisitor.java @@ -14,8 +14,6 @@ package dev.cel.common.navigation; -import static java.lang.Math.max; - import com.google.common.collect.ImmutableList; import dev.cel.common.ast.CelExpr; import dev.cel.common.ast.CelExpr.CelCall; @@ -31,12 +29,15 @@ final class CelNavigableExprVisitor { private static final int MAX_DESCENDANTS_RECURSION_DEPTH = 500; - private final Stream.Builder streamBuilder; + private final Stream.Builder streamBuilder; + private final ExprHeightCalculator exprHeightCalculator; private final TraversalOrder traversalOrder; private final int maxDepth; - private CelNavigableExprVisitor(int maxDepth, TraversalOrder traversalOrder) { + private CelNavigableExprVisitor( + int maxDepth, ExprHeightCalculator exprHeightCalculator, TraversalOrder traversalOrder) { this.maxDepth = maxDepth; + this.exprHeightCalculator = exprHeightCalculator; this.traversalOrder = traversalOrder; this.streamBuilder = Stream.builder(); } @@ -84,14 +85,16 @@ static Stream collect( */ static Stream collect( CelNavigableExpr navigableExpr, int maxDepth, TraversalOrder traversalOrder) { - CelNavigableExprVisitor visitor = new CelNavigableExprVisitor(maxDepth, traversalOrder); + ExprHeightCalculator exprHeightCalculator = new ExprHeightCalculator(navigableExpr.expr()); + CelNavigableExprVisitor visitor = + new CelNavigableExprVisitor(maxDepth, exprHeightCalculator, traversalOrder); - visitor.visit(navigableExpr.toBuilder()); + visitor.visit(navigableExpr); - return visitor.streamBuilder.build().map(CelNavigableExpr.Builder::build); + return visitor.streamBuilder.build(); } - private int visit(CelNavigableExpr.Builder navigableExpr) { + private void visit(CelNavigableExpr navigableExpr) { if (navigableExpr.depth() > MAX_DESCENDANTS_RECURSION_DEPTH - 1) { throw new IllegalStateException("Max recursion depth reached."); } @@ -101,108 +104,89 @@ private int visit(CelNavigableExpr.Builder navigableExpr) { streamBuilder.add(navigableExpr); } - int height = 1; switch (navigableExpr.getKind()) { case CALL: - height += visit(navigableExpr, navigableExpr.expr().call()); + visit(navigableExpr, navigableExpr.expr().call()); break; case CREATE_LIST: - height += visit(navigableExpr, navigableExpr.expr().createList()); + visit(navigableExpr, navigableExpr.expr().createList()); break; case SELECT: - height += visit(navigableExpr, navigableExpr.expr().select()); + visit(navigableExpr, navigableExpr.expr().select()); break; case CREATE_STRUCT: - height += visitStruct(navigableExpr, navigableExpr.expr().createStruct()); + visitStruct(navigableExpr, navigableExpr.expr().createStruct()); break; case CREATE_MAP: - height += visitMap(navigableExpr, navigableExpr.expr().createMap()); + visitMap(navigableExpr, navigableExpr.expr().createMap()); break; case COMPREHENSION: - height += visit(navigableExpr, navigableExpr.expr().comprehension()); + visit(navigableExpr, navigableExpr.expr().comprehension()); break; default: - // This is a leaf node - height = 0; break; } - navigableExpr.setHeight(height); if (addToStream && traversalOrder.equals(TraversalOrder.POST_ORDER)) { streamBuilder.add(navigableExpr); } - - return height; } - private int visit(CelNavigableExpr.Builder navigableExpr, CelCall call) { - int targetHeight = 0; + private void visit(CelNavigableExpr navigableExpr, CelCall call) { if (call.target().isPresent()) { - CelNavigableExpr.Builder target = newNavigableChild(navigableExpr, call.target().get()); - targetHeight = visit(target); + visit(newNavigableChild(navigableExpr, call.target().get())); } - int argumentHeight = visitExprList(call.args(), navigableExpr); - return max(targetHeight, argumentHeight); + visitExprList(call.args(), navigableExpr); } - private int visit(CelNavigableExpr.Builder navigableExpr, CelCreateList createList) { - return visitExprList(createList.elements(), navigableExpr); + private void visit(CelNavigableExpr navigableExpr, CelCreateList createList) { + visitExprList(createList.elements(), navigableExpr); } - private int visit(CelNavigableExpr.Builder navigableExpr, CelSelect selectExpr) { - CelNavigableExpr.Builder operand = newNavigableChild(navigableExpr, selectExpr.operand()); - return visit(operand); + private void visit(CelNavigableExpr navigableExpr, CelSelect selectExpr) { + CelNavigableExpr operand = newNavigableChild(navigableExpr, selectExpr.operand()); + visit(operand); } - private int visit(CelNavigableExpr.Builder navigableExpr, CelComprehension comprehension) { - int maxHeight = 0; - maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.iterRange())), maxHeight); - maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.accuInit())), maxHeight); - maxHeight = - max(visit(newNavigableChild(navigableExpr, comprehension.loopCondition())), maxHeight); - maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.loopStep())), maxHeight); - maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.result())), maxHeight); - - return maxHeight; + private void visit(CelNavigableExpr navigableExpr, CelComprehension comprehension) { + visit(newNavigableChild(navigableExpr, comprehension.iterRange())); + visit(newNavigableChild(navigableExpr, comprehension.accuInit())); + visit(newNavigableChild(navigableExpr, comprehension.loopCondition())); + visit(newNavigableChild(navigableExpr, comprehension.loopStep())); + visit(newNavigableChild(navigableExpr, comprehension.result())); } - private int visitStruct(CelNavigableExpr.Builder navigableExpr, CelCreateStruct struct) { - int maxHeight = 0; + private void visitStruct(CelNavigableExpr navigableExpr, CelCreateStruct struct) { for (CelCreateStruct.Entry entry : struct.entries()) { - CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value()); - maxHeight = max(visit(value), maxHeight); + visit(newNavigableChild(navigableExpr, entry.value())); } - return maxHeight; } - private int visitMap(CelNavigableExpr.Builder navigableExpr, CelCreateMap map) { - int maxHeight = 0; + private void visitMap(CelNavigableExpr navigableExpr, CelCreateMap map) { for (CelCreateMap.Entry entry : map.entries()) { - CelNavigableExpr.Builder key = newNavigableChild(navigableExpr, entry.key()); - maxHeight = max(visit(key), maxHeight); + CelNavigableExpr key = newNavigableChild(navigableExpr, entry.key()); + visit(key); - CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value()); - maxHeight = max(visit(value), maxHeight); + CelNavigableExpr value = newNavigableChild(navigableExpr, entry.value()); + visit(value); } - return 0; } - private int visitExprList( - ImmutableList createListExpr, CelNavigableExpr.Builder parent) { - int maxHeight = 0; + private void visitExprList(ImmutableList createListExpr, CelNavigableExpr parent) { for (CelExpr expr : createListExpr) { - CelNavigableExpr.Builder arg = newNavigableChild(parent, expr); - maxHeight = max(visit(arg), maxHeight); + visit(newNavigableChild(parent, expr)); } - return maxHeight; } - private CelNavigableExpr.Builder newNavigableChild( - CelNavigableExpr.Builder parent, CelExpr expr) { - return CelNavigableExpr.builder() - .setExpr(expr) - .setDepth(parent.depth() + 1) - .setParentBuilder(parent); + private CelNavigableExpr newNavigableChild(CelNavigableExpr parent, CelExpr expr) { + CelNavigableExpr.Builder navigableExpr = + CelNavigableExpr.builder() + .setExpr(expr) + .setDepth(parent.depth() + 1) + .setHeight(exprHeightCalculator.getHeight(expr.id())) + .setParent(parent); + + return navigableExpr.build(); } } diff --git a/common/src/main/java/dev/cel/common/navigation/ExprHeightCalculator.java b/common/src/main/java/dev/cel/common/navigation/ExprHeightCalculator.java new file mode 100644 index 00000000..f6555a09 --- /dev/null +++ b/common/src/main/java/dev/cel/common/navigation/ExprHeightCalculator.java @@ -0,0 +1,132 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.common.navigation; + +import static java.lang.Math.max; + +import com.google.common.collect.ImmutableList; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelComprehension; +import dev.cel.common.ast.CelExpr.CelCreateList; +import dev.cel.common.ast.CelExpr.CelCreateMap; +import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.CelSelect; +import java.util.HashMap; + +/** Package-private class to assist computing the height of expression nodes. */ +final class ExprHeightCalculator { + // Store hashmap instead of immutable map for performance, such that this helper class can be + // instantiated faster. + private final HashMap idToHeight; + + ExprHeightCalculator(CelExpr celExpr) { + this.idToHeight = new HashMap<>(); + visit(celExpr); + } + + int getHeight(Long exprId) { + if (!idToHeight.containsKey(exprId)) { + throw new IllegalStateException("Height not found for expression id: " + exprId); + } + + return idToHeight.get(exprId); + } + + private int visit(CelExpr celExpr) { + int height = 1; + switch (celExpr.exprKind().getKind()) { + case CALL: + height += visit(celExpr.call()); + break; + case CREATE_LIST: + height += visit(celExpr.createList()); + break; + case SELECT: + height += visit(celExpr.select()); + break; + case CREATE_STRUCT: + height += visitStruct(celExpr.createStruct()); + break; + case CREATE_MAP: + height += visitMap(celExpr.createMap()); + break; + case COMPREHENSION: + height += visit(celExpr.comprehension()); + break; + default: + // This is a leaf node + height = 0; + break; + } + + idToHeight.put(celExpr.id(), height); + return height; + } + + private int visit(CelCall call) { + int targetHeight = 0; + if (call.target().isPresent()) { + targetHeight = visit(call.target().get()); + } + + int argumentHeight = visitExprList(call.args()); + return max(targetHeight, argumentHeight); + } + + private int visit(CelCreateList createList) { + return visitExprList(createList.elements()); + } + + private int visit(CelSelect selectExpr) { + return visit(selectExpr.operand()); + } + + private int visit(CelComprehension comprehension) { + int maxHeight = 0; + maxHeight = max(visit(comprehension.iterRange()), maxHeight); + maxHeight = max(visit(comprehension.accuInit()), maxHeight); + maxHeight = max(visit(comprehension.loopCondition()), maxHeight); + maxHeight = max(visit(comprehension.loopStep()), maxHeight); + maxHeight = max(visit(comprehension.result()), maxHeight); + + return maxHeight; + } + + private int visitStruct(CelCreateStruct struct) { + int maxHeight = 0; + for (CelCreateStruct.Entry entry : struct.entries()) { + maxHeight = max(visit(entry.value()), maxHeight); + } + return maxHeight; + } + + private int visitMap(CelCreateMap map) { + int maxHeight = 0; + for (CelCreateMap.Entry entry : map.entries()) { + maxHeight = max(visit(entry.key()), maxHeight); + maxHeight = max(visit(entry.value()), maxHeight); + } + return maxHeight; + } + + private int visitExprList(ImmutableList createListExpr) { + int maxHeight = 0; + for (CelExpr expr : createListExpr) { + maxHeight = max(visit(expr), maxHeight); + } + return maxHeight; + } +} diff --git a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java index e263020f..710267e5 100644 --- a/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java +++ b/common/src/test/java/dev/cel/common/navigation/CelNavigableExprVisitorTest.java @@ -153,22 +153,28 @@ public void add_fromLeaf_heightSetForParents() throws Exception { CelCompiler compiler = CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: - // + - // + 2 - // 1 a - CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2").getAst(); + // + + // + 3 + // + 2 + // a 1 + CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2 + 3").getAst(); CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); - CelNavigableExpr oneConst = + ImmutableList.Builder heights = ImmutableList.builder(); + CelNavigableExpr navigableExpr = navigableAst .getRoot() - .descendants() - .filter(node -> node.expr().constantOrDefault().int64Value() == 1) + .allNodes() + .filter(node -> node.expr().identOrDefault().name().equals("a")) .findAny() .get(); - assertThat(oneConst.height()).isEqualTo(0); // 1 - assertThat(oneConst.parent().get().height()).isEqualTo(1); // + - assertThat(oneConst.parent().get().parent().get().height()).isEqualTo(2); // root + heights.add(navigableExpr.height()); + while (navigableExpr.parent().isPresent()) { + navigableExpr = navigableExpr.parent().get(); + heights.add(navigableExpr.height()); + } + + assertThat(heights.build()).containsExactly(3, 2, 1, 0); } @Test @@ -178,9 +184,9 @@ public void add_children_heightSet(@TestParameter TraversalOrder traversalOrder) CelCompilerFactory.standardCelCompilerBuilder().addVar("a", SimpleType.INT).build(); // Tree shape: // + - // + 2 - // + a - // 3 + // + 3 + // + 2 + // a 1 CelAbstractSyntaxTree ast = compiler.compile("1 + a + 2 + 3").getAst(); CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); @@ -840,7 +846,6 @@ public void comprehension_allNodes_parentsPopulated() throws Exception { ImmutableList allNodes = navigableAst.getRoot().allNodes(TraversalOrder.PRE_ORDER).collect(toImmutableList()); - CelExpr iterRangeConstExpr = CelExpr.ofConstantExpr(2, CelConstant.ofValue(true)); CelExpr iterRange = CelExpr.ofCreateListExpr(1, ImmutableList.of(iterRangeConstExpr), ImmutableList.of()); diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java index 57a7485f..bb677826 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java @@ -717,9 +717,9 @@ private enum CseTestCase { "cel.@block([{\"key\": \"test\"}], {?\"key\":" + " optional.of(\"test\")}[?\"bogus\"].or(@index0[?\"bogus\"]).orValue(@index0[\"key\"])" + " == \"test\")", - "cel.@block([{\"key\": \"test\"}, @index0[\"key\"], @index0[?\"bogus\"], {?\"key\":" - + " optional.of(\"test\")}, @index3[?\"bogus\"], @index4.or(@index2)," - + " @index5.orValue(@index1)], @index6 == \"test\")"), + "cel.@block([{\"key\": \"test\"}, @index0[\"key\"], @index0[?\"bogus\"]," + + " optional.of(\"test\"), {?\"key\": @index3}, @index4[?\"bogus\"]," + + " @index5.or(@index2), @index6.orValue(@index1)], @index7 == \"test\")"), OPTIONAL_MESSAGE( "TestAllTypes{?single_int64: optional.ofNonZeroValue(1), ?single_int32:" + " optional.of(4)}.single_int32 + TestAllTypes{?single_int64:"