Skip to content

Commit

Permalink
Accept eliminable custom functions as an option
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610849935
  • Loading branch information
l46kok authored and copybara-github committed Feb 27, 2024
1 parent dc9083b commit 9ebad48
Show file tree
Hide file tree
Showing 16 changed files with 1,998 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ java_library(
"//optimizer:ast_optimizer",
"//optimizer:mutable_ast",
"//parser:operator",
"@maven//:com_google_errorprone_error_prone_annotations",
"@maven//:com_google_guava_guava",
"@maven//:org_jspecify_jspecify",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package dev.cel.optimizer.optimizers;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Arrays.stream;
Expand All @@ -25,6 +26,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import dev.cel.bundle.CelBuilder;
import dev.cel.checker.Standard;
import dev.cel.common.CelAbstractSyntaxTree;
Expand Down Expand Up @@ -52,6 +54,7 @@
import dev.cel.optimizer.MutableAst.MangledComprehensionAst;
import dev.cel.parser.Operator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -84,25 +87,25 @@
* </pre>
*/
public class SubexpressionOptimizer implements CelAstOptimizer {
private static final ImmutableSet<String> CSE_DEFAULT_ELIMINABLE_FUNCTIONS =
Streams.concat(
stream(Operator.values()).map(Operator::getFunction),
stream(Standard.Function.values()).map(Standard.Function::getFunction),
stream(CelOptionalLibrary.Function.values()).map(Function::getFunction))
.collect(toImmutableSet());
private static final SubexpressionOptimizer INSTANCE =
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
private static final String BIND_IDENTIFIER_PREFIX = "@r";
private static final String MANGLED_COMPREHENSION_IDENTIFIER_PREFIX = "@c";
private static final String MANGLED_COMPREHENSION_RESULT_PREFIX = "@x";
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
private static final String BLOCK_INDEX_PREFIX = "@index";
private static final ImmutableSet<String> CSE_ALLOWED_FUNCTIONS =
Streams.concat(
stream(Operator.values()).map(Operator::getFunction),
stream(Standard.Function.values()).map(Standard.Function::getFunction),
stream(CelOptionalLibrary.Function.values()).map(Function::getFunction))
.collect(toImmutableSet());

private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);

private final SubexpressionOptimizerOptions cseOptions;
private final MutableAst mutableAst;
private final ImmutableSet<String> cseEliminableFunctions;

/**
* Returns a default instance of common subexpression elimination optimizer with preconfigured
Expand Down Expand Up @@ -359,7 +362,7 @@ private Stream<CelExpr> getAllCseCandidatesStream(
return CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes()
.filter(SubexpressionOptimizer::canEliminate)
.filter(this::canEliminate)
.map(CelNavigableExpr::expr)
.filter(expr -> areSemanticallyEqual(cseCandidate, expr));
}
Expand Down Expand Up @@ -423,7 +426,7 @@ private Optional<CelNavigableExpr> findCseCandidateWithRecursionDepth(
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes(TraversalOrder.POST_ORDER)
.filter(SubexpressionOptimizer::canEliminate)
.filter(this::canEliminate)
.filter(node -> node.height() <= recursionLimit)
.filter(node -> !areSemanticallyEqual(ast.getExpr(), node.expr()))
.collect(toImmutableList());
Expand Down Expand Up @@ -462,18 +465,18 @@ private Optional<CelNavigableExpr> findCseCandidateWithCommonSubexpr(CelAbstract
CelNavigableAst.fromAst(ast)
.getRoot()
.allNodes(TraversalOrder.PRE_ORDER)
.filter(SubexpressionOptimizer::canEliminate)
.filter(this::canEliminate)
.collect(toImmutableList());

return findCseCandidateWithCommonSubexpr(allNodes);
}

private static boolean canEliminate(CelNavigableExpr navigableExpr) {
private boolean canEliminate(CelNavigableExpr navigableExpr) {
return !navigableExpr.getKind().equals(Kind.CONSTANT)
&& !navigableExpr.getKind().equals(Kind.IDENT)
&& !navigableExpr.expr().identOrDefault().name().startsWith(BIND_IDENTIFIER_PREFIX)
&& !navigableExpr.expr().selectOrDefault().testOnly()
&& containsAllowedFunctionOnly(navigableExpr)
&& containsEliminableFunctionOnly(navigableExpr)
&& isWithinInlineableComprehension(navigableExpr);
}

Expand Down Expand Up @@ -507,13 +510,13 @@ private boolean areSemanticallyEqual(CelExpr expr1, CelExpr expr2) {
return normalizeForEquality(expr1).equals(normalizeForEquality(expr2));
}

private static boolean containsAllowedFunctionOnly(CelNavigableExpr navigableExpr) {
private boolean containsEliminableFunctionOnly(CelNavigableExpr navigableExpr) {
return navigableExpr
.allNodes()
.allMatch(
node -> {
if (node.getKind().equals(Kind.CALL)) {
return CSE_ALLOWED_FUNCTIONS.contains(node.expr().call().function());
return cseEliminableFunctions.contains(node.expr().call().function());
}

return true;
Expand Down Expand Up @@ -580,6 +583,8 @@ public abstract static class SubexpressionOptimizerOptions {

public abstract int subexpressionMaxRecursionDepth();

public abstract ImmutableSet<String> eliminableFunctions();

/** Builder for configuring the {@link SubexpressionOptimizerOptions}. */
@AutoValue.Builder
public abstract static class Builder {
Expand Down Expand Up @@ -630,11 +635,34 @@ public abstract static class Builder {
*/
public abstract Builder subexpressionMaxRecursionDepth(int value);

abstract ImmutableSet.Builder<String> eliminableFunctionsBuilder();

/**
* Adds a collection of custom functions that will be a candidate for common subexpression
* elimination. By default, standard functions are eliminable.
*
* <p>Note that the implementation of custom functions must be free of side effects.
*/
@CanIgnoreReturnValue
public Builder addEliminableFunctions(Iterable<String> functions) {
checkNotNull(functions);
this.eliminableFunctionsBuilder().addAll(functions);
return this;
}

/** See {@link #addEliminableFunctions(Iterable)}. */
@CanIgnoreReturnValue
public Builder addEliminableFunctions(String... functions) {
return addEliminableFunctions(Arrays.asList(functions));
}

public abstract SubexpressionOptimizerOptions build();

Builder() {}
}

abstract Builder toBuilder();

/** Returns a new options builder with recommended defaults pre-configured. */
public static Builder newBuilder() {
return new AutoValue_SubexpressionOptimizer_SubexpressionOptimizerOptions.Builder()
Expand All @@ -650,5 +678,10 @@ public static Builder newBuilder() {
private SubexpressionOptimizer(SubexpressionOptimizerOptions cseOptions) {
this.cseOptions = cseOptions;
this.mutableAst = MutableAst.newInstance(cseOptions.iterationLimit());
this.cseEliminableFunctions =
ImmutableSet.<String>builder()
.addAll(CSE_DEFAULT_ELIMINABLE_FUNCTIONS)
.addAll(cseOptions.eliminableFunctions())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import dev.cel.parser.CelStandardMacro;
import dev.cel.parser.CelUnparser;
import dev.cel.parser.CelUnparserFactory;
import dev.cel.runtime.CelRuntime.CelFunctionBinding;
import dev.cel.testing.BaselineTestCase;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.NestedTestAllTypes;
import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes;
Expand Down Expand Up @@ -68,6 +69,13 @@ public class SubexpressionOptimizerBaselineTest extends BaselineTestCase {
.build();
private static final Cel CEL = newCelBuilder().build();

private static final SubexpressionOptimizerOptions OPTIMIZER_COMMON_OPTIONS =
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.addEliminableFunctions("pure_custom_func")
.build();

private String overriddenBaseFilePath = "";

@Before
Expand Down Expand Up @@ -287,8 +295,16 @@ private static CelBuilder newCelBuilder() {
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"custom_func",
newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT)))
"pure_custom_func",
newGlobalOverload("pure_custom_func_overload", SimpleType.INT, SimpleType.INT)),
CelFunctionDecl.newFunctionDeclaration(
"non_pure_custom_func",
newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addFunctionBindings(
// This is pure, but for the purposes of excluding it as a CSE candidate, pretend that
// it isn't.
CelFunctionBinding.from("non_pure_custom_func_overload", Long.class, val -> val),
CelFunctionBinding.from("pure_custom_func_overload", Long.class, val -> val))
.addVar("x", SimpleType.DYN)
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
Expand All @@ -302,70 +318,26 @@ private static CelOptimizer newCseOptimizer(Cel cel, SubexpressionOptimizerOptio

@SuppressWarnings("Immutable") // Test only
private enum CseTestOptimizer {
CASCADED_BINDS(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(false)
.build()),
BLOCK_COMMON_SUBEXPR_ONLY(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.build()),
CASCADED_BINDS(OPTIMIZER_COMMON_OPTIONS.toBuilder().enableCelBlock(false).build()),
BLOCK_COMMON_SUBEXPR_ONLY(OPTIMIZER_COMMON_OPTIONS),
BLOCK_RECURSION_DEPTH_1(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(1)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(1).build()),
BLOCK_RECURSION_DEPTH_2(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(2)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(2).build()),
BLOCK_RECURSION_DEPTH_3(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(3)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(3).build()),
BLOCK_RECURSION_DEPTH_4(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(4)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(4).build()),
BLOCK_RECURSION_DEPTH_5(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(5)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(5).build()),
BLOCK_RECURSION_DEPTH_6(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(6)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(6).build()),
BLOCK_RECURSION_DEPTH_7(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(7)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(7).build()),
BLOCK_RECURSION_DEPTH_8(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(8)
.build()),
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(8).build()),
BLOCK_RECURSION_DEPTH_9(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.subexpressionMaxRecursionDepth(9)
.build());
OPTIMIZER_COMMON_OPTIONS.toBuilder().subexpressionMaxRecursionDepth(9).build());

private final CelOptimizer celOptimizer;

Expand Down Expand Up @@ -507,7 +479,16 @@ private enum CseTestCase {
"('h' + 'e' + 'l' + 'l' + 'o' + ' world').matches('hello')"),
CALL_BOTH_ARGUMENT_TARGET_NESTED_NO_COMMON_SUBEXPR(
"('h' + 'e' + 'l' + 'l' + 'o' + ' world').matches('w' + 'o' + 'r' + 'l' + 'd')"),
;
CUSTOM_FUNCTION_INELIMINABLE(
"non_pure_custom_func(msg.oneof_type.payload.single_int64) +"
+ " non_pure_custom_func(msg.oneof_type.payload.single_int32) +"
+ " non_pure_custom_func(msg.oneof_type.payload.single_int64) +"
+ " non_pure_custom_func(msg.single_int64)"),
CUSTOM_FUNCTION_ELIMINABLE(
"pure_custom_func(msg.oneof_type.payload.single_int64) +"
+ " pure_custom_func(msg.oneof_type.payload.single_int32) +"
+ " pure_custom_func(msg.oneof_type.payload.single_int64) +"
+ " pure_custom_func(msg.single_int64)");
private final String source;

CseTestCase(String source) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.types.ListType;
import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.extensions.CelExtensions;
import dev.cel.extensions.CelOptionalLibrary;
import dev.cel.optimizer.CelOptimizationException;
import dev.cel.optimizer.CelOptimizer;
import dev.cel.optimizer.CelOptimizerFactory;
Expand Down Expand Up @@ -107,18 +105,15 @@ public class SubexpressionOptimizerTest {
private static CelBuilder newCelBuilder() {
return CelFactory.standardCelBuilder()
.addMessageTypes(TestAllTypes.getDescriptor())
.setContainer("dev.cel.testing.testdata.proto3")
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.setOptions(
CelOptions.current().enableTimestampEpoch(true).populateMacroCalls(true).build())
.addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings())
.addRuntimeLibraries(CelOptionalLibrary.INSTANCE)
.addCompilerLibraries(CelExtensions.bindings())
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
"custom_func",
newGlobalOverload("custom_func_overload", SimpleType.INT, SimpleType.INT)))
"non_pure_custom_func",
newGlobalOverload("non_pure_custom_func_overload", SimpleType.INT, SimpleType.INT)))
.addVar("x", SimpleType.DYN)
.addVar("opt_x", OptionalType.create(SimpleType.DYN))
.addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName()));
}

Expand Down Expand Up @@ -156,9 +151,10 @@ private enum CseNoOpTestCase {
// Constants and identifiers within a function
CONST_WITHIN_FUNCTION("size(\"hello\" + \"hello\" + \"hello\")"),
IDENT_WITHIN_FUNCTION("string(x + x + x)"),
// Non-standard functions are considered non-pure for time being
NON_STANDARD_FUNCTION_1("custom_func(1) + custom_func(1)"),
NON_STANDARD_FUNCTION_2("1 + custom_func(1) + 1 + custom_func(1)"),
// Non-standard functions that have not been explicitly added as a candidate are not
// optimized.
NON_STANDARD_FUNCTION_1("non_pure_custom_func(1) + non_pure_custom_func(1)"),
NON_STANDARD_FUNCTION_2("1 + non_pure_custom_func(1) + 1 + non_pure_custom_func(1)"),
// Duplicated but nested calls.
NESTED_FUNCTION("int(timestamp(int(timestamp(1000000000))))"),
// This cannot be optimized. Extracting the common subexpression would presence test
Expand Down
Loading

0 comments on commit 9ebad48

Please sign in to comment.