From 73d29cf6e87cfd303adaa51e4517d3b98cabd712 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 28 Feb 2024 13:51:04 -0800 Subject: [PATCH] Fix replacing namespaced identifiers for accu_init PiperOrigin-RevId: 611217621 --- .../java/dev/cel/checker/ExprChecker.java | 33 ++++++++++++------- .../extensions/CelBindingsExtensionsTest.java | 10 ++++-- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/checker/src/main/java/dev/cel/checker/ExprChecker.java b/checker/src/main/java/dev/cel/checker/ExprChecker.java index 5c0c21ff..2506466c 100644 --- a/checker/src/main/java/dev/cel/checker/ExprChecker.java +++ b/checker/src/main/java/dev/cel/checker/ExprChecker.java @@ -487,11 +487,8 @@ private CelExpr visit(CelExpr expr, CelExpr.CelCreateList createList) { @CheckReturnValue private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) { CelExpr visitedRange = visit(compre.iterRange()); - if (namespacedDeclarations && !visitedRange.equals(compre.iterRange())) { - expr = replaceComprehensionRangeSubtree(expr, visitedRange); - } - CelExpr init = visit(compre.accuInit()); - CelType accuType = env.getType(init); + CelExpr visitedInit = visit(compre.accuInit()); + CelType accuType = env.getType(visitedInit); CelType rangeType = inferenceContext.specialize(env.getType(visitedRange)); CelType varType; switch (rangeType.kind()) { @@ -533,17 +530,25 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) { CelExpr condition = visit(compre.loopCondition()); assertType(condition, SimpleType.BOOL); CelExpr visitedStep = visit(compre.loopStep()); - if (namespacedDeclarations && !visitedStep.equals(compre.loopStep())) { - expr = replaceComprehensionStepSubtree(expr, visitedStep); - } assertType(visitedStep, accuType); // Forget iteration variable, as result expression must only depend on accu. env.exitScope(); CelExpr visitedResult = visit(compre.result()); - if (namespacedDeclarations && !visitedResult.equals(compre.result())) { - expr = replaceComprehensionResultSubtree(expr, visitedResult); - } env.exitScope(); + if (namespacedDeclarations) { + if (!visitedRange.equals(compre.iterRange())) { + expr = replaceComprehensionRangeSubtree(expr, visitedRange); + } + if (!visitedInit.equals(compre.accuInit())) { + expr = replaceComprehensionAccuInitSubtree(expr, visitedInit); + } + if (!visitedStep.equals(compre.loopStep())) { + expr = replaceComprehensionStepSubtree(expr, visitedStep); + } + if (!visitedResult.equals(compre.result())) { + expr = replaceComprehensionResultSubtree(expr, visitedResult); + } + } env.setType(expr, inferenceContext.specialize(env.getType(visitedResult))); return expr; } @@ -872,6 +877,12 @@ private static CelExpr replaceMapEntryValueSubtree(CelExpr expr, CelExpr newValu return expr.toBuilder().setCreateMap(createMap).build(); } + private static CelExpr replaceComprehensionAccuInitSubtree(CelExpr expr, CelExpr newAccuInit) { + CelExpr.CelComprehension newComprehension = + expr.comprehension().toBuilder().setAccuInit(newAccuInit).build(); + return expr.toBuilder().setComprehension(newComprehension).build(); + } + private static CelExpr replaceComprehensionRangeSubtree(CelExpr expr, CelExpr newRange) { CelExpr.CelComprehension newComprehension = expr.comprehension().toBuilder().setIterRange(newRange).build(); diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index 218435d9..1eba6e22 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -46,10 +46,13 @@ public final class CelBindingsExtensionsTest { private static final CelCompiler COMPILER = CelCompilerFactory.standardCelCompilerBuilder() .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addLibraries(CelExtensions.bindings()) + .addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) .build(); - private static final CelRuntime RUNTIME = CelRuntimeFactory.standardCelRuntimeBuilder().build(); + private static final CelRuntime RUNTIME = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addLibraries(CelOptionalLibrary.INSTANCE) + .build(); private enum BindingTestCase { BOOL_LITERAL("cel.bind(t, true, t)"), @@ -63,7 +66,8 @@ private enum BindingTestCase { BIND_WITH_EXISTS_TRUE( "cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"), BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))"), - BIND_WITH_MAP("[1,2,3].map(x, cel.bind(y, x + x, [y, y])) == [[2, 2], [4, 4], [6, 6]]"); + BIND_WITH_MAP("[1,2,3].map(x, cel.bind(y, x + x, [y, y])) == [[2, 2], [4, 4], [6, 6]]"), + BIND_OPTIONAL_LIST("cel.bind(r0, optional.none(), [?r0, ?r0]) == []"); private final String source;