Skip to content

Commit

Permalink
Compute navigable expr heights in a separate pass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609416184
  • Loading branch information
l46kok authored and copybara-github committed Feb 22, 2024
1 parent 5ebf44e commit 4f8f455
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 102 deletions.
1 change: 1 addition & 0 deletions common/src/main/java/dev/cel/common/navigation/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ java_library(
"CelNavigableAst.java",
"CelNavigableExpr.java",
"CelNavigableExprVisitor.java",
"ExprHeightCalculator.java",
],
tags = [
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
package dev.cel.common.navigation;

import com.google.auto.value.AutoValue;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.CheckReturnValue;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelComprehension;
import dev.cel.common.ast.CelExpr.ExprKind;
Expand Down Expand Up @@ -69,7 +67,12 @@ public long id() {

/** Constructs a new instance of {@link CelNavigableExpr} from {@link CelExpr}. */
public static CelNavigableExpr fromExpr(CelExpr expr) {
return CelNavigableExpr.builder().setExpr(expr).build();
ExprHeightCalculator exprHeightCalculator = new ExprHeightCalculator(expr);

return CelNavigableExpr.builder()
.setExpr(expr)
.setHeight(exprHeightCalculator.getHeight(expr.id()))
.build();
}

/**
Expand Down Expand Up @@ -136,8 +139,6 @@ public static Builder builder() {
@AutoValue.Builder
public abstract static class Builder {

private Builder parentBuilder;

public abstract CelExpr expr();

public abstract int depth();
Expand All @@ -150,25 +151,11 @@ public ExprKind.Kind getKind() {

abstract Builder setParent(CelNavigableExpr value);

@CanIgnoreReturnValue
public Builder setParentBuilder(CelNavigableExpr.Builder value) {
parentBuilder = value;
return this;
}

public abstract Builder setDepth(int value);

public abstract Builder setHeight(int value);

public abstract CelNavigableExpr autoBuild();

@CheckReturnValue
public CelNavigableExpr build() {
if (parentBuilder != null) {
setParent(parentBuilder.build());
}
return autoBuild();
}
public abstract CelNavigableExpr build();
}

public abstract Builder toBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package dev.cel.common.navigation;

import static java.lang.Math.max;

import com.google.common.collect.ImmutableList;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
Expand All @@ -31,12 +29,15 @@
final class CelNavigableExprVisitor {
private static final int MAX_DESCENDANTS_RECURSION_DEPTH = 500;

private final Stream.Builder<CelNavigableExpr.Builder> streamBuilder;
private final Stream.Builder<CelNavigableExpr> streamBuilder;
private final ExprHeightCalculator exprHeightCalculator;
private final TraversalOrder traversalOrder;
private final int maxDepth;

private CelNavigableExprVisitor(int maxDepth, TraversalOrder traversalOrder) {
private CelNavigableExprVisitor(
int maxDepth, ExprHeightCalculator exprHeightCalculator, TraversalOrder traversalOrder) {
this.maxDepth = maxDepth;
this.exprHeightCalculator = exprHeightCalculator;
this.traversalOrder = traversalOrder;
this.streamBuilder = Stream.builder();
}
Expand Down Expand Up @@ -84,14 +85,16 @@ static Stream<CelNavigableExpr> collect(
*/
static Stream<CelNavigableExpr> collect(
CelNavigableExpr navigableExpr, int maxDepth, TraversalOrder traversalOrder) {
CelNavigableExprVisitor visitor = new CelNavigableExprVisitor(maxDepth, traversalOrder);
ExprHeightCalculator exprHeightCalculator = new ExprHeightCalculator(navigableExpr.expr());
CelNavigableExprVisitor visitor =
new CelNavigableExprVisitor(maxDepth, exprHeightCalculator, traversalOrder);

visitor.visit(navigableExpr.toBuilder());
visitor.visit(navigableExpr);

return visitor.streamBuilder.build().map(CelNavigableExpr.Builder::build);
return visitor.streamBuilder.build();
}

private int visit(CelNavigableExpr.Builder navigableExpr) {
private void visit(CelNavigableExpr navigableExpr) {
if (navigableExpr.depth() > MAX_DESCENDANTS_RECURSION_DEPTH - 1) {
throw new IllegalStateException("Max recursion depth reached.");
}
Expand All @@ -101,108 +104,89 @@ private int visit(CelNavigableExpr.Builder navigableExpr) {
streamBuilder.add(navigableExpr);
}

int height = 1;
switch (navigableExpr.getKind()) {
case CALL:
height += visit(navigableExpr, navigableExpr.expr().call());
visit(navigableExpr, navigableExpr.expr().call());
break;
case CREATE_LIST:
height += visit(navigableExpr, navigableExpr.expr().createList());
visit(navigableExpr, navigableExpr.expr().createList());
break;
case SELECT:
height += visit(navigableExpr, navigableExpr.expr().select());
visit(navigableExpr, navigableExpr.expr().select());
break;
case CREATE_STRUCT:
height += visitStruct(navigableExpr, navigableExpr.expr().createStruct());
visitStruct(navigableExpr, navigableExpr.expr().createStruct());
break;
case CREATE_MAP:
height += visitMap(navigableExpr, navigableExpr.expr().createMap());
visitMap(navigableExpr, navigableExpr.expr().createMap());
break;
case COMPREHENSION:
height += visit(navigableExpr, navigableExpr.expr().comprehension());
visit(navigableExpr, navigableExpr.expr().comprehension());
break;
default:
// This is a leaf node
height = 0;
break;
}

navigableExpr.setHeight(height);
if (addToStream && traversalOrder.equals(TraversalOrder.POST_ORDER)) {
streamBuilder.add(navigableExpr);
}

return height;
}

private int visit(CelNavigableExpr.Builder navigableExpr, CelCall call) {
int targetHeight = 0;
private void visit(CelNavigableExpr navigableExpr, CelCall call) {
if (call.target().isPresent()) {
CelNavigableExpr.Builder target = newNavigableChild(navigableExpr, call.target().get());
targetHeight = visit(target);
visit(newNavigableChild(navigableExpr, call.target().get()));
}

int argumentHeight = visitExprList(call.args(), navigableExpr);
return max(targetHeight, argumentHeight);
visitExprList(call.args(), navigableExpr);
}

private int visit(CelNavigableExpr.Builder navigableExpr, CelCreateList createList) {
return visitExprList(createList.elements(), navigableExpr);
private void visit(CelNavigableExpr navigableExpr, CelCreateList createList) {
visitExprList(createList.elements(), navigableExpr);
}

private int visit(CelNavigableExpr.Builder navigableExpr, CelSelect selectExpr) {
CelNavigableExpr.Builder operand = newNavigableChild(navigableExpr, selectExpr.operand());
return visit(operand);
private void visit(CelNavigableExpr navigableExpr, CelSelect selectExpr) {
CelNavigableExpr operand = newNavigableChild(navigableExpr, selectExpr.operand());
visit(operand);
}

private int visit(CelNavigableExpr.Builder navigableExpr, CelComprehension comprehension) {
int maxHeight = 0;
maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.iterRange())), maxHeight);
maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.accuInit())), maxHeight);
maxHeight =
max(visit(newNavigableChild(navigableExpr, comprehension.loopCondition())), maxHeight);
maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.loopStep())), maxHeight);
maxHeight = max(visit(newNavigableChild(navigableExpr, comprehension.result())), maxHeight);

return maxHeight;
private void visit(CelNavigableExpr navigableExpr, CelComprehension comprehension) {
visit(newNavigableChild(navigableExpr, comprehension.iterRange()));
visit(newNavigableChild(navigableExpr, comprehension.accuInit()));
visit(newNavigableChild(navigableExpr, comprehension.loopCondition()));
visit(newNavigableChild(navigableExpr, comprehension.loopStep()));
visit(newNavigableChild(navigableExpr, comprehension.result()));
}

private int visitStruct(CelNavigableExpr.Builder navigableExpr, CelCreateStruct struct) {
int maxHeight = 0;
private void visitStruct(CelNavigableExpr navigableExpr, CelCreateStruct struct) {
for (CelCreateStruct.Entry entry : struct.entries()) {
CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value());
maxHeight = max(visit(value), maxHeight);
visit(newNavigableChild(navigableExpr, entry.value()));
}
return maxHeight;
}

private int visitMap(CelNavigableExpr.Builder navigableExpr, CelCreateMap map) {
int maxHeight = 0;
private void visitMap(CelNavigableExpr navigableExpr, CelCreateMap map) {
for (CelCreateMap.Entry entry : map.entries()) {
CelNavigableExpr.Builder key = newNavigableChild(navigableExpr, entry.key());
maxHeight = max(visit(key), maxHeight);
CelNavigableExpr key = newNavigableChild(navigableExpr, entry.key());
visit(key);

CelNavigableExpr.Builder value = newNavigableChild(navigableExpr, entry.value());
maxHeight = max(visit(value), maxHeight);
CelNavigableExpr value = newNavigableChild(navigableExpr, entry.value());
visit(value);
}
return 0;
}

private int visitExprList(
ImmutableList<CelExpr> createListExpr, CelNavigableExpr.Builder parent) {
int maxHeight = 0;
private void visitExprList(ImmutableList<CelExpr> createListExpr, CelNavigableExpr parent) {
for (CelExpr expr : createListExpr) {
CelNavigableExpr.Builder arg = newNavigableChild(parent, expr);
maxHeight = max(visit(arg), maxHeight);
visit(newNavigableChild(parent, expr));
}
return maxHeight;
}

private CelNavigableExpr.Builder newNavigableChild(
CelNavigableExpr.Builder parent, CelExpr expr) {
return CelNavigableExpr.builder()
.setExpr(expr)
.setDepth(parent.depth() + 1)
.setParentBuilder(parent);
private CelNavigableExpr newNavigableChild(CelNavigableExpr parent, CelExpr expr) {
CelNavigableExpr.Builder navigableExpr =
CelNavigableExpr.builder()
.setExpr(expr)
.setDepth(parent.depth() + 1)
.setHeight(exprHeightCalculator.getHeight(expr.id()))
.setParent(parent);

return navigableExpr.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package dev.cel.common.navigation;

import static java.lang.Math.max;

import com.google.common.collect.ImmutableList;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
import dev.cel.common.ast.CelExpr.CelComprehension;
import dev.cel.common.ast.CelExpr.CelCreateList;
import dev.cel.common.ast.CelExpr.CelCreateMap;
import dev.cel.common.ast.CelExpr.CelCreateStruct;
import dev.cel.common.ast.CelExpr.CelSelect;
import java.util.HashMap;

/** Package-private class to assist computing the height of expression nodes. */
final class ExprHeightCalculator {
// Store hashmap instead of immutable map for performance, such that this helper class can be
// instantiated faster.
private final HashMap<Long, Integer> idToHeight;

ExprHeightCalculator(CelExpr celExpr) {
this.idToHeight = new HashMap<>();
visit(celExpr);
}

int getHeight(Long exprId) {
if (!idToHeight.containsKey(exprId)) {
throw new IllegalStateException("Height not found for expression id: " + exprId);
}

return idToHeight.get(exprId);
}

private int visit(CelExpr celExpr) {
int height = 1;
switch (celExpr.exprKind().getKind()) {
case CALL:
height += visit(celExpr.call());
break;
case CREATE_LIST:
height += visit(celExpr.createList());
break;
case SELECT:
height += visit(celExpr.select());
break;
case CREATE_STRUCT:
height += visitStruct(celExpr.createStruct());
break;
case CREATE_MAP:
height += visitMap(celExpr.createMap());
break;
case COMPREHENSION:
height += visit(celExpr.comprehension());
break;
default:
// This is a leaf node
height = 0;
break;
}

idToHeight.put(celExpr.id(), height);
return height;
}

private int visit(CelCall call) {
int targetHeight = 0;
if (call.target().isPresent()) {
targetHeight = visit(call.target().get());
}

int argumentHeight = visitExprList(call.args());
return max(targetHeight, argumentHeight);
}

private int visit(CelCreateList createList) {
return visitExprList(createList.elements());
}

private int visit(CelSelect selectExpr) {
return visit(selectExpr.operand());
}

private int visit(CelComprehension comprehension) {
int maxHeight = 0;
maxHeight = max(visit(comprehension.iterRange()), maxHeight);
maxHeight = max(visit(comprehension.accuInit()), maxHeight);
maxHeight = max(visit(comprehension.loopCondition()), maxHeight);
maxHeight = max(visit(comprehension.loopStep()), maxHeight);
maxHeight = max(visit(comprehension.result()), maxHeight);

return maxHeight;
}

private int visitStruct(CelCreateStruct struct) {
int maxHeight = 0;
for (CelCreateStruct.Entry entry : struct.entries()) {
maxHeight = max(visit(entry.value()), maxHeight);
}
return maxHeight;
}

private int visitMap(CelCreateMap map) {
int maxHeight = 0;
for (CelCreateMap.Entry entry : map.entries()) {
maxHeight = max(visit(entry.key()), maxHeight);
maxHeight = max(visit(entry.value()), maxHeight);
}
return maxHeight;
}

private int visitExprList(ImmutableList<CelExpr> createListExpr) {
int maxHeight = 0;
for (CelExpr expr : createListExpr) {
maxHeight = max(visit(expr), maxHeight);
}
return maxHeight;
}
}
Loading

0 comments on commit 4f8f455

Please sign in to comment.