Skip to content

Commit

Permalink
Fix replacing namespaced identifiers for accu_init
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611217621
  • Loading branch information
l46kok authored and copybara-github committed Feb 28, 2024
1 parent 776f95a commit 73d29cf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
33 changes: 22 additions & 11 deletions checker/src/main/java/dev/cel/checker/ExprChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -487,11 +487,8 @@ private CelExpr visit(CelExpr expr, CelExpr.CelCreateList createList) {
@CheckReturnValue
private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
CelExpr visitedRange = visit(compre.iterRange());
if (namespacedDeclarations && !visitedRange.equals(compre.iterRange())) {
expr = replaceComprehensionRangeSubtree(expr, visitedRange);
}
CelExpr init = visit(compre.accuInit());
CelType accuType = env.getType(init);
CelExpr visitedInit = visit(compre.accuInit());
CelType accuType = env.getType(visitedInit);
CelType rangeType = inferenceContext.specialize(env.getType(visitedRange));
CelType varType;
switch (rangeType.kind()) {
Expand Down Expand Up @@ -533,17 +530,25 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
CelExpr condition = visit(compre.loopCondition());
assertType(condition, SimpleType.BOOL);
CelExpr visitedStep = visit(compre.loopStep());
if (namespacedDeclarations && !visitedStep.equals(compre.loopStep())) {
expr = replaceComprehensionStepSubtree(expr, visitedStep);
}
assertType(visitedStep, accuType);
// Forget iteration variable, as result expression must only depend on accu.
env.exitScope();
CelExpr visitedResult = visit(compre.result());
if (namespacedDeclarations && !visitedResult.equals(compre.result())) {
expr = replaceComprehensionResultSubtree(expr, visitedResult);
}
env.exitScope();
if (namespacedDeclarations) {
if (!visitedRange.equals(compre.iterRange())) {
expr = replaceComprehensionRangeSubtree(expr, visitedRange);
}
if (!visitedInit.equals(compre.accuInit())) {
expr = replaceComprehensionAccuInitSubtree(expr, visitedInit);
}
if (!visitedStep.equals(compre.loopStep())) {
expr = replaceComprehensionStepSubtree(expr, visitedStep);
}
if (!visitedResult.equals(compre.result())) {
expr = replaceComprehensionResultSubtree(expr, visitedResult);
}
}
env.setType(expr, inferenceContext.specialize(env.getType(visitedResult)));
return expr;
}
Expand Down Expand Up @@ -872,6 +877,12 @@ private static CelExpr replaceMapEntryValueSubtree(CelExpr expr, CelExpr newValu
return expr.toBuilder().setCreateMap(createMap).build();
}

private static CelExpr replaceComprehensionAccuInitSubtree(CelExpr expr, CelExpr newAccuInit) {
CelExpr.CelComprehension newComprehension =
expr.comprehension().toBuilder().setAccuInit(newAccuInit).build();
return expr.toBuilder().setComprehension(newComprehension).build();
}

private static CelExpr replaceComprehensionRangeSubtree(CelExpr expr, CelExpr newRange) {
CelExpr.CelComprehension newComprehension =
expr.comprehension().toBuilder().setIterRange(newRange).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,13 @@ public final class CelBindingsExtensionsTest {
private static final CelCompiler COMPILER =
CelCompilerFactory.standardCelCompilerBuilder()
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.addLibraries(CelExtensions.bindings())
.addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings())
.build();

private static final CelRuntime RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build();
private static final CelRuntime RUNTIME =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addLibraries(CelOptionalLibrary.INSTANCE)
.build();

private enum BindingTestCase {
BOOL_LITERAL("cel.bind(t, true, t)"),
Expand All @@ -63,7 +66,8 @@ private enum BindingTestCase {
BIND_WITH_EXISTS_TRUE(
"cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"),
BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))"),
BIND_WITH_MAP("[1,2,3].map(x, cel.bind(y, x + x, [y, y])) == [[2, 2], [4, 4], [6, 6]]");
BIND_WITH_MAP("[1,2,3].map(x, cel.bind(y, x + x, [y, y])) == [[2, 2], [4, 4], [6, 6]]"),
BIND_OPTIONAL_LIST("cel.bind(r0, optional.none(), [?r0, ?r0]) == []");

private final String source;

Expand Down

0 comments on commit 73d29cf

Please sign in to comment.