Skip to content

Commit

Permalink
Tag AST containing cel.block call
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607499642
  • Loading branch information
l46kok authored and copybara-github committed Feb 16, 2024
1 parent 70ef6f9 commit 496ab08
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 10 deletions.
14 changes: 8 additions & 6 deletions common/src/main/java/dev/cel/common/CelSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ public Builder toBuilder() {
return new Builder(codePoints, lineOffsets)
.setDescription(description)
.addPositionsMap(positions)
.addAllExtensions(extensions)
.addAllMacroCalls(macroCalls);
}

Expand Down Expand Up @@ -354,7 +355,7 @@ private LineAndOffset(int line, int offset) {
*/
@AutoValue
@Immutable
abstract static class Extension {
public abstract static class Extension {

/** Identifier for the extension. Example: constant_folding */
abstract String id();
Expand All @@ -371,9 +372,10 @@ abstract static class Extension {
*/
abstract ImmutableList<Component> affectedComponents();

/** Version of the extension */
@AutoValue
@Immutable
abstract static class Version {
public abstract static class Version {

/**
* Major version changes indicate different required support level from the required
Expand All @@ -388,13 +390,13 @@ abstract static class Version {
abstract long minor();

/** Create a new instance of Version with the provided major and minor values. */
static Version of(long major, long minor) {
public static Version of(long major, long minor) {
return new AutoValue_CelSource_Extension_Version(major, minor);
}
}

/** CEL component specifier. */
enum Component {
public enum Component {
/** Unspecified, default. */
COMPONENT_UNSPECIFIED,
/** Parser. Converts a CEL string to an AST. */
Expand All @@ -406,14 +408,14 @@ enum Component {
}

@CheckReturnValue
static Extension create(String id, Version version, Iterable<Component> components) {
public static Extension create(String id, Version version, Iterable<Component> components) {
checkNotNull(version);
checkNotNull(components);
return new AutoValue_CelSource_Extension(id, version, ImmutableList.copyOf(components));
}

@CheckReturnValue
static Extension create(String id, Version version, Component... components) {
public static Extension create(String id, Version version, Component... components) {
return create(id, version, Arrays.asList(components));
}
}
Expand Down
5 changes: 3 additions & 2 deletions optimizer/src/main/java/dev/cel/optimizer/MutableAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -456,11 +456,12 @@ private CelSource normalizeMacroSource(
ExprIdGenerator idGenerator) {
// Remove the macro metadata that no longer exists in the AST due to being replaced.
celSource = celSource.toBuilder().clearMacroCall(exprIdToReplace).build();
CelSource.Builder sourceBuilder =
CelSource.newBuilder().addAllExtensions(celSource.getExtensions());
if (celSource.getMacroCalls().isEmpty()) {
return CelSource.newBuilder().build();
return sourceBuilder.build();
}

CelSource.Builder sourceBuilder = CelSource.newBuilder();
ImmutableMap<Long, CelExpr> allExprs =
CelNavigableExpr.fromExpr(mutatedRoot.build())
.allNodes()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelSource;
import dev.cel.common.CelSource.Extension;
import dev.cel.common.CelSource.Extension.Component;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
Expand Down Expand Up @@ -87,6 +90,10 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
stream(Operator.values()).map(Operator::getFunction),
stream(Standard.Function.values()).map(Standard.Function::getFunction))
.collect(toImmutableSet());

private static final Extension CEL_BLOCK_AST_EXTENSION_TAG =
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME);

private final SubexpressionOptimizerOptions cseOptions;
private final MutableAst mutableAst;

Expand Down Expand Up @@ -209,7 +216,16 @@ private CelAbstractSyntaxTree optimizeUsingCelBlock(

// Restore the expected result type the environment had prior to optimization.
celBuilder.setResultType(resultType);
return astToModify;

return tagAstExtension(astToModify);
}

private static CelAbstractSyntaxTree tagAstExtension(CelAbstractSyntaxTree ast) {
// Tag the extension
CelSource.Builder celSourceBuilder =
ast.getSource().toBuilder().addAllExtensions(CEL_BLOCK_AST_EXTENSION_TAG);

return CelAbstractSyntaxTree.newParsedAst(ast.getExpr(), celSourceBuilder.build());
}

/**
Expand Down Expand Up @@ -510,7 +526,9 @@ public abstract static class Builder {

/**
* Rewrites the optimized AST using cel.@block call instead of cascaded cel.bind macros, aimed
* to produce a more compact AST.
* to produce a more compact AST. {@link com.google.api.expr.SourceInfo.Extension} field will
* be populated in the AST to inform that special runtime support is required to evaluate the
* optimized expression.
*/
public abstract Builder enableCelBlock(boolean value);

Expand Down
21 changes: 21 additions & 0 deletions optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelSource;
import dev.cel.common.CelSource.Extension;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr;
import dev.cel.common.ast.CelExpr.CelCall;
Expand Down Expand Up @@ -99,6 +102,7 @@ public void mutableAst_nonMacro_sourceCleared() throws Exception {
assertThat(mutatedAst.getSource().getDescription()).isEmpty();
assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty();
assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty();
assertThat(mutatedAst.getSource().getExtensions()).isEmpty();
assertThat(mutatedAst.getSource().getMacroCalls()).isEmpty();
}

Expand All @@ -113,9 +117,26 @@ public void mutableAst_macro_sourceMacroCallsPopulated() throws Exception {
assertThat(mutatedAst.getSource().getDescription()).isEmpty();
assertThat(mutatedAst.getSource().getLineOffsets()).isEmpty();
assertThat(mutatedAst.getSource().getPositionsMap()).isEmpty();
assertThat(mutatedAst.getSource().getExtensions()).isEmpty();
assertThat(mutatedAst.getSource().getMacroCalls()).isNotEmpty();
}

@Test
public void mutableAst_astContainsTaggedExtension_retained() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("has(TestAllTypes{}.single_int32)").getAst();
Extension extension = Extension.create("test", Version.of(1, 1));
CelSource celSource = ast.getSource().toBuilder().addAllExtensions(extension).build();
ast =
CelAbstractSyntaxTree.newCheckedAst(
ast.getExpr(), celSource, ast.getReferenceMap(), ast.getTypeMap());

CelAbstractSyntaxTree mutatedAst =
MUTABLE_AST.replaceSubtree(
ast, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1);

assertThat(mutatedAst.getSource().getExtensions()).containsExactly(extension);
}

@Test
@TestParameters("{source: '[1].exists(x, x > 0)', expectedMacroCallSize: 1}")
@TestParameters(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOptions;
import dev.cel.common.CelOverloadDecl;
import dev.cel.common.CelSource.Extension;
import dev.cel.common.CelSource.Extension.Component;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.CelValidationException;
import dev.cel.common.CelVarDecl;
import dev.cel.common.ast.CelConstant;
Expand Down Expand Up @@ -1109,6 +1112,27 @@ public void iterationLimitReached_throws(boolean enableCelBlock) throws Exceptio
assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Max iteration count reached.");
}

@Test
public void celBlock_astExtensionTagged() throws Exception {
CelAbstractSyntaxTree ast = CEL.compile("size(x) + size(x)").getAst();
CelOptimizer optimizer =
CelOptimizerFactory.standardCelOptimizerBuilder(CEL)
.addAstOptimizers(
SubexpressionOptimizer.newInstance(
SubexpressionOptimizerOptions.newBuilder()
.populateMacroCalls(true)
.enableCelBlock(true)
.build()),
ConstantFoldingOptimizer.getInstance())
.build();

CelAbstractSyntaxTree optimizedAst = optimizer.optimize(ast);

assertThat(optimizedAst.getSource().getExtensions())
.containsExactly(
Extension.create("cel_block", Version.of(1L, 1L), Component.COMPONENT_RUNTIME));
}

private enum BlockTestCase {
BOOL_LITERAL("cel.block([true, false], index0 || index1)"),
STRING_CONCAT("cel.block(['a' + 'b', index0 + 'c'], index1 + 'd') == 'abcd'"),
Expand Down

0 comments on commit 496ab08

Please sign in to comment.