Skip to content

Commit

Permalink
Add ConstantFoldingOptions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599969298
  • Loading branch information
l46kok authored and copybara-github committed Jan 20, 2024
1 parent f9f370d commit 7d28e89
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ java_library(
tags = [
],
deps = [
"//:auto_value",
"//bundle:cel",
"//common",
"//common:compiler_common",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.onlyElement;

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import dev.cel.bundle.Cel;
import dev.cel.common.CelAbstractSyntaxTree;
Expand Down Expand Up @@ -47,9 +48,22 @@
* calls and select statements with their evaluated result.
*/
public final class ConstantFoldingOptimizer implements CelAstOptimizer {
private static final int MAX_ITERATION_COUNT = 400;
private static final ConstantFoldingOptimizer INSTANCE =
new ConstantFoldingOptimizer(ConstantFoldingOptions.newBuilder().build());

public static final ConstantFoldingOptimizer INSTANCE = new ConstantFoldingOptimizer();
/** Returns a default instance of constant folding optimizer with preconfigured defaults. */
public static ConstantFoldingOptimizer getInstance() {
return INSTANCE;
}

/**
* Returns a new instance of constant folding optimizer configured with the provided {@link
* ConstantFoldingOptions}.
*/
public static ConstantFoldingOptimizer newInstance(
ConstantFoldingOptions constantFoldingOptions) {
return new ConstantFoldingOptimizer(constantFoldingOptions);
}

// Use optional.of and optional.none as sentinel function names for folding optional calls.
// TODO: Leverage CelValue representation of Optionals instead when available.
Expand All @@ -58,6 +72,7 @@ public final class ConstantFoldingOptimizer implements CelAstOptimizer {
private static final CelExpr OPTIONAL_NONE_EXPR =
CelExpr.ofCallExpr(0, Optional.empty(), OPTIONAL_NONE_FUNCTION, ImmutableList.of());

private final ConstantFoldingOptions constantFoldingOptions;
private final MutableAst mutableAst;

@Override
Expand All @@ -67,7 +82,7 @@ public CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel)
int iterCount = 0;
while (true) {
iterCount++;
if (iterCount == MAX_ITERATION_COUNT) {
if (iterCount >= constantFoldingOptions.maxIterationLimit()) {
throw new IllegalStateException("Max iteration count reached.");
}
Optional<CelExpr> foldableExpr =
Expand Down Expand Up @@ -553,7 +568,37 @@ private CelAbstractSyntaxTree pruneOptionalStructElements(
return ast;
}

private ConstantFoldingOptimizer() {
this.mutableAst = MutableAst.newInstance(MAX_ITERATION_COUNT);
/** Options to configure how Constant Folding behave. */
@AutoValue
public abstract static class ConstantFoldingOptions {
public abstract int maxIterationLimit();

/** Builder for configuring the {@link ConstantFoldingOptions}. */
@AutoValue.Builder
public abstract static class Builder {

/**
* Limit the number of iteration while performing constant folding. An exception is thrown if
* the iteration count exceeds the set value.
*/
public abstract Builder maxIterationLimit(int value);

public abstract ConstantFoldingOptions build();

Builder() {}
}

/** Returns a new options builder with recommended defaults pre-configured. */
public static Builder newBuilder() {
return new AutoValue_ConstantFoldingOptimizer_ConstantFoldingOptions.Builder()
.maxIterationLimit(400);
}

ConstantFoldingOptions() {}
}

private ConstantFoldingOptimizer(ConstantFoldingOptions constantFoldingOptions) {
this.constantFoldingOptions = constantFoldingOptions;
this.mutableAst = MutableAst.newInstance(constantFoldingOptions.maxIterationLimit());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import dev.cel.optimizer.CelOptimizationException;
import dev.cel.optimizer.CelOptimizer;
import dev.cel.optimizer.CelOptimizerFactory;
import dev.cel.optimizer.optimizers.ConstantFoldingOptimizer.ConstantFoldingOptions;
import dev.cel.parser.CelStandardMacro;
import dev.cel.parser.CelUnparser;
import dev.cel.parser.CelUnparserFactory;
Expand All @@ -53,7 +54,7 @@ public class ConstantFoldingOptimizerTest {

private static final CelOptimizer CEL_OPTIMIZER =
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
.addAstOptimizers(ConstantFoldingOptimizer.INSTANCE)
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
.build();

private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser();
Expand Down Expand Up @@ -211,7 +212,7 @@ public void constantFold_macros_macroCallMetadataPopulated(String source, String
.build();
CelOptimizer celOptimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(ConstantFoldingOptimizer.INSTANCE)
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
.build();
CelAbstractSyntaxTree ast = cel.compile(source).getAst();

Expand Down Expand Up @@ -253,7 +254,7 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E
.build();
CelOptimizer celOptimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(ConstantFoldingOptimizer.INSTANCE)
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
.build();
CelAbstractSyntaxTree ast = cel.compile(source).getAst();

Expand Down Expand Up @@ -299,7 +300,7 @@ public void constantFold_withMacroCallPopulated_comprehensionsAreReplacedWithNot
.build();
CelOptimizer celOptimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(ConstantFoldingOptimizer.INSTANCE)
.addAstOptimizers(ConstantFoldingOptimizer.getInstance())
.build();
CelAbstractSyntaxTree ast =
cel.compile("[1, 1 + 1, 1 + 1+ 1].map(i, i).filter(j, j % 2 == x)").getAst();
Expand Down Expand Up @@ -384,17 +385,19 @@ public void constantFold_astProducesConsistentlyNumberedIds() throws Exception {
public void iterationLimitReached_throws() throws Exception {
StringBuilder sb = new StringBuilder();
sb.append("0");
for (int i = 1; i < 400; i++) {
for (int i = 1; i < 200; i++) {
sb.append(" + ").append(i);
} // 0 + 1 + 2 + 3 + ... 400
} // 0 + 1 + 2 + 3 + ... 200
Cel cel =
CelFactory.standardCelBuilder()
.setOptions(CelOptions.current().maxParseRecursionDepth(400).build())
.setOptions(CelOptions.current().maxParseRecursionDepth(200).build())
.build();
CelAbstractSyntaxTree ast = cel.compile(sb.toString()).getAst();
CelOptimizer optimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
.addAstOptimizers(ConstantFoldingOptimizer.INSTANCE)
.addAstOptimizers(
ConstantFoldingOptimizer.newInstance(
ConstantFoldingOptions.newBuilder().maxIterationLimit(200).build()))
.build();

CelOptimizationException e =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ public void cse_applyConstFoldingAfter() throws Exception {
.addAstOptimizers(
SubexpressionOptimizer.newInstance(
SubexpressionOptimizerOptions.newBuilder().populateMacroCalls(true).build()),
ConstantFoldingOptimizer.INSTANCE)
ConstantFoldingOptimizer.getInstance())
.build();

CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast);
Expand Down

0 comments on commit 7d28e89

Please sign in to comment.