diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index a11a6a93..218435d9 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -62,7 +62,8 @@ private enum BindingTestCase { + " \"aaaa:aaaa\""), BIND_WITH_EXISTS_TRUE( "cel.bind(valid_elems, [1, 2, 3], [3, 4, 5].exists(e, e in valid_elems))"), - BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))"); + BIND_WITH_EXISTS_FALSE("cel.bind(valid_elems, [1, 2, 3], ![4, 5].exists(e, e in valid_elems))"), + BIND_WITH_MAP("[1,2,3].map(x, cel.bind(y, x + x, [y, y])) == [[2, 2], [4, 4], [6, 6]]"); private final String source; diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java index f21cd18b..f27882ed 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java @@ -275,11 +275,6 @@ private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, Stri return IntermediateResult.create(typeValue); } - IntermediateResult cachedResult = frame.lookupLazilyEvaluatedResult(name).orElse(null); - if (cachedResult != null) { - return cachedResult; - } - IntermediateResult rawResult = frame.resolveSimpleName(name, expr.id()); Object value = rawResult.value(); boolean isLazyExpression = value instanceof LazyExpression; @@ -885,7 +880,6 @@ private static class ExecutionFrame { private final CelEvaluationListener evaluationListener; private final int maxIterations; private final ArrayDeque resolvers; - private final Map lazyEvalResultCache; private RuntimeUnknownResolver currentResolver; private int iterations; @@ -898,7 +892,6 @@ private ExecutionFrame( this.resolvers.add(resolver); this.currentResolver = resolver; this.maxIterations = maxIterations; - this.lazyEvalResultCache = new HashMap<>(); } private CelEvaluationListener getEvaluationListener() { @@ -929,12 +922,9 @@ private Optional resolveAttribute(CelAttribute attr) { return currentResolver.resolveAttribute(attr); } - private Optional lookupLazilyEvaluatedResult(String name) { - return Optional.ofNullable(lazyEvalResultCache.get(name)); - } - - private void cacheLazilyEvaluatedResult(String name, IntermediateResult result) { - lazyEvalResultCache.put(name, result); + private void cacheLazilyEvaluatedResult( + String name, DefaultInterpreter.IntermediateResult result) { + currentResolver.cacheLazilyEvaluatedResult(name, result); } private void pushScope(ImmutableMap scope) { diff --git a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java index e68270ca..1c1be9b9 100644 --- a/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java +++ b/runtime/src/main/java/dev/cel/runtime/RuntimeUnknownResolver.java @@ -16,6 +16,7 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import dev.cel.common.annotations.Internal; +import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -33,6 +34,7 @@ public class RuntimeUnknownResolver { /** The underlying resolver for known values. */ private final GlobalResolver resolver; + /** Resolver for unknown and resolved attributes. */ private final CelAttributeResolver attributeResolver; @@ -112,6 +114,10 @@ DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId attr, InterpreterUtil.valueOrUnknown(result, exprId)); } + void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) { + // no-op. Caching is handled in ScopedResolver. + } + /** * Attempt to resolve an attribute bound to a context variable. This is used to shadow lazily * resolved values behind field accesses and index operations. @@ -127,6 +133,7 @@ ScopedResolver withScope(Map vars static final class ScopedResolver extends RuntimeUnknownResolver { private final RuntimeUnknownResolver parent; private final Map shadowedVars; + private final Map lazyEvalResultCache; private ScopedResolver( RuntimeUnknownResolver parent, @@ -134,16 +141,26 @@ private ScopedResolver( super(parent.resolver, parent.attributeResolver, parent.attributeTrackingEnabled); this.parent = parent; this.shadowedVars = shadowedVars; + this.lazyEvalResultCache = new HashMap<>(); } @Override DefaultInterpreter.IntermediateResult resolveSimpleName(String name, Long exprId) { - DefaultInterpreter.IntermediateResult shadowed = shadowedVars.get(name); - if (shadowed != null) { - return shadowed; + DefaultInterpreter.IntermediateResult result = lazyEvalResultCache.get(name); + if (result != null) { + return result; + } + result = shadowedVars.get(name); + if (result != null) { + return result; } return parent.resolveSimpleName(name, exprId); } + + @Override + void cacheLazilyEvaluatedResult(String name, DefaultInterpreter.IntermediateResult result) { + lazyEvalResultCache.put(name, result); + } } /** Null implementation for attribute resolution. */