From aa0fb8cb55a2fb912d9e5f91e03eb57711b8325c Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Wed, 31 Jan 2024 15:58:31 -0800 Subject: [PATCH] Add serialization capability for tagged AST extensions PiperOrigin-RevId: 603185360 --- WORKSPACE | 6 +- .../common/CelProtoAbstractSyntaxTree.java | 75 +++++++++++ .../main/java/dev/cel/common/CelSource.java | 116 +++++++++++++++--- .../CelProtoAbstractSyntaxTreeTest.java | 10 ++ .../java/dev/cel/common/CelSourceTest.java | 24 ++++ 5 files changed, 212 insertions(+), 19 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index dd6291bf..351e8221 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -138,10 +138,10 @@ http_archive( # cel-spec api/expr canonical protos http_archive( name = "cel_spec", - sha256 = "6c2d9ec6dd5e2afbc41423dcaeca9fdd73edcd76554a897caee5b9a2b0e20491", - strip_prefix = "cel-spec-0.11.0", + sha256 = "3579c97b13548714f9059ef6f30c5264d439efef4b438e76e7180709efd93a6b", + strip_prefix = "cel-spec-0.14.0", urls = [ - "https://github.com/google/cel-spec/archive/refs/tags/v0.11.0.tar.gz", + "https://github.com/google/cel-spec/archive/refs/tags/v0.14.0.tar.gz", ], ) diff --git a/common/src/main/java/dev/cel/common/CelProtoAbstractSyntaxTree.java b/common/src/main/java/dev/cel/common/CelProtoAbstractSyntaxTree.java index fa0719e8..ef69c174 100644 --- a/common/src/main/java/dev/cel/common/CelProtoAbstractSyntaxTree.java +++ b/common/src/main/java/dev/cel/common/CelProtoAbstractSyntaxTree.java @@ -15,16 +15,22 @@ package dev.cel.common; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import dev.cel.expr.CheckedExpr; import dev.cel.expr.Expr; import dev.cel.expr.ParsedExpr; import dev.cel.expr.SourceInfo; +import dev.cel.expr.SourceInfo.Extension; +import dev.cel.expr.SourceInfo.Extension.Component; +import dev.cel.expr.SourceInfo.Extension.Version; import dev.cel.expr.Type; +import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CheckReturnValue; import dev.cel.common.ast.CelExprConverter; import dev.cel.common.types.CelTypes; +import java.util.Collection; import java.util.Map.Entry; /** @@ -49,6 +55,9 @@ private CelProtoAbstractSyntaxTree(CheckedExpr checkedExpr) { .addAllMacroCalls( CelExprConverter.exprMacroCallsToCelExprMacroCalls( checkedExpr.getSourceInfo().getMacroCallsMap())) + .addAllExtensions( + fromExprExtensionsToCelExtensions( + checkedExpr.getSourceInfo().getExtensionsList())) .setDescription(checkedExpr.getSourceInfo().getLocation()) .build(), checkedExpr.getReferenceMapMap().entrySet().stream() @@ -69,6 +78,8 @@ private CelProtoAbstractSyntaxTree(CelAbstractSyntaxTree ast) { SourceInfo.newBuilder() .setLocation(ast.getSource().getDescription()) .addAllLineOffsets(ast.getSource().getLineOffsets()) + .addAllExtensions( + fromCelExtensionsToExprExtensions(ast.getSource().getExtensions())) .putAllMacroCalls( ast.getSource().getMacroCalls().entrySet().stream() .collect( @@ -173,5 +184,69 @@ public ParsedExpr toParsedExpr() { public Type getProtoResultType() { return CelTypes.celTypeToType(ast.getResultType()); } + + private static ImmutableList fromCelExtensionsToExprExtensions( + Collection extensions) { + return extensions.stream() + .map( + celSourceExtension -> + Extension.newBuilder() + .setId(celSourceExtension.id()) + .setVersion( + Version.newBuilder() + .setMajor(celSourceExtension.version().major()) + .setMinor(celSourceExtension.version().minor())) + .addAllAffectedComponents( + celSourceExtension.affectedComponents().stream() + .map( + component -> { + switch (component) { + case COMPONENT_UNSPECIFIED: + return Component.COMPONENT_UNSPECIFIED; + case COMPONENT_PARSER: + return Component.COMPONENT_PARSER; + case COMPONENT_TYPE_CHECKER: + return Component.COMPONENT_TYPE_CHECKER; + case COMPONENT_RUNTIME: + return Component.COMPONENT_RUNTIME; + } + throw new AssertionError( + "Unexpected component kind: " + component); + }) + .collect(toImmutableList())) + .build()) + .collect(toImmutableList()); + } + + private static ImmutableList fromExprExtensionsToCelExtensions( + Collection extensions) { + return extensions.stream() + .map( + exprExtension -> + CelSource.Extension.create( + exprExtension.getId(), + CelSource.Extension.Version.of( + exprExtension.getVersion().getMajor(), + exprExtension.getVersion().getMinor()), + exprExtension.getAffectedComponentsList().stream() + .map( + component -> { + switch (component) { + case COMPONENT_UNSPECIFIED: + return CelSource.Extension.Component.COMPONENT_UNSPECIFIED; + case COMPONENT_PARSER: + return CelSource.Extension.Component.COMPONENT_PARSER; + case COMPONENT_TYPE_CHECKER: + return CelSource.Extension.Component.COMPONENT_TYPE_CHECKER; + case COMPONENT_RUNTIME: + return CelSource.Extension.Component.COMPONENT_RUNTIME; + case UNRECOGNIZED: + // fall-through + } + throw new AssertionError("Unexpected component kind: " + component); + }) + .collect(toImmutableList()))) + .collect(toImmutableList()); + } } // LINT.ThenChange(CelProtoV1Alpha1AbstractSyntaxTree.java) diff --git a/common/src/main/java/dev/cel/common/CelSource.java b/common/src/main/java/dev/cel/common/CelSource.java index 88ab19c6..3f4c0d2d 100644 --- a/common/src/main/java/dev/cel/common/CelSource.java +++ b/common/src/main/java/dev/cel/common/CelSource.java @@ -17,6 +17,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.auto.value.AutoValue; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -26,6 +27,7 @@ import dev.cel.common.ast.CelExpr; import dev.cel.common.internal.CelCodePointArray; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -34,7 +36,6 @@ /** Represents the source content of an expression and related metadata. */ @Immutable public final class CelSource { - private static final Splitter LINE_SPLITTER = Splitter.on('\n'); private final CelCodePointArray codePoints; @@ -42,13 +43,15 @@ public final class CelSource { private final ImmutableList lineOffsets; private final ImmutableMap positions; private final ImmutableMap macroCalls; + private final ImmutableList extensions; private CelSource(Builder builder) { - codePoints = checkNotNull(builder.codePoints); - description = checkNotNull(builder.description); - positions = checkNotNull(builder.positions.buildOrThrow()); - lineOffsets = checkNotNull(ImmutableList.copyOf(builder.lineOffsets)); - macroCalls = checkNotNull(ImmutableMap.copyOf(builder.macroCalls)); + this.codePoints = checkNotNull(builder.codePoints); + this.description = checkNotNull(builder.description); + this.positions = checkNotNull(builder.positions.buildOrThrow()); + this.lineOffsets = checkNotNull(ImmutableList.copyOf(builder.lineOffsets)); + this.macroCalls = checkNotNull(ImmutableMap.copyOf(builder.macroCalls)); + this.extensions = checkNotNull(builder.extensions.build()); } public CelCodePointArray getContent() { @@ -77,6 +80,10 @@ public ImmutableMap getMacroCalls() { return macroCalls; } + public ImmutableList getExtensions() { + return extensions; + } + /** See {@link #getLocationOffset(int, int)}. */ public Optional getLocationOffset(CelSourceLocation location) { checkNotNull(location); @@ -201,6 +208,7 @@ public static final class Builder { private final List lineOffsets; private final ImmutableMap.Builder positions; private final Map macroCalls; + private final ImmutableList.Builder extensions; private String description; @@ -213,7 +221,8 @@ private Builder(CelCodePointArray codePoints, List lineOffsets) { this.lineOffsets = checkNotNull(lineOffsets); this.positions = ImmutableMap.builder(); this.macroCalls = new HashMap<>(); - description = ""; + this.extensions = ImmutableList.builder(); + this.description = ""; } @CanIgnoreReturnValue @@ -229,15 +238,6 @@ public Builder addLineOffsets(int lineOffset) { return this; } - @CanIgnoreReturnValue - public Builder addLineOffsets(int... lineOffsets) { - // Purposefully not using Arrays.asList to avoid int boxing/unboxing. - for (int index = 0; index != lineOffsets.length; index++) { - addLineOffsets(lineOffsets[index]); - } - return this; - } - @CanIgnoreReturnValue public Builder addAllLineOffsets(Iterable lineOffsets) { for (int lineOffset : lineOffsets) { @@ -277,6 +277,18 @@ public Builder clearMacroCall(long exprId) { return this; } + @CanIgnoreReturnValue + public Builder addAllExtensions(Iterable extensions) { + checkNotNull(extensions); + this.extensions.addAll(extensions); + return this; + } + + @CanIgnoreReturnValue + public Builder addAllExtensions(Extension... extensions) { + return addAllExtensions(Arrays.asList(extensions)); + } + /** See {@link #getLocationOffset(int, int)}. */ public Optional getLocationOffset(CelSourceLocation location) { checkNotNull(location); @@ -333,4 +345,76 @@ private LineAndOffset(int line, int offset) { int line; int offset; } + + /** + * Tag for an extension that were used while parsing or type checking the source expression. For + * example, optimizations that require special runtime support may be specified. These are used to + * check feature support between components in separate implementations. This can be used to + * either skip redundant work or report an error if the extension is unsupported. + */ + @AutoValue + @Immutable + abstract static class Extension { + + /** Identifier for the extension. Example: constant_folding */ + abstract String id(); + + /** + * Version info. May be skipped if it isn't meaningful for the extension. (for example + * constant_folding might always be v0.0). + */ + abstract Version version(); + + /** + * If set, the listed components must understand the extension for the expression to evaluate + * correctly. + */ + abstract ImmutableList affectedComponents(); + + @AutoValue + @Immutable + abstract static class Version { + + /** + * Major version changes indicate different required support level from the required + * components. + */ + abstract long major(); + + /** + * Minor version changes must not change the observed behavior from existing implementations, + * but may be provided informational. + */ + abstract long minor(); + + /** Create a new instance of Version with the provided major and minor values. */ + static Version of(long major, long minor) { + return new AutoValue_CelSource_Extension_Version(major, minor); + } + } + + /** CEL component specifier. */ + enum Component { + /** Unspecified, default. */ + COMPONENT_UNSPECIFIED, + /** Parser. Converts a CEL string to an AST. */ + COMPONENT_PARSER, + /** Type checker. Checks that references in an AST are defined and types agree. */ + COMPONENT_TYPE_CHECKER, + /** Runtime. Evaluates a parsed and optionally checked CEL AST against a context. */ + COMPONENT_RUNTIME; + } + + @CheckReturnValue + static Extension create(String id, Version version, Iterable components) { + checkNotNull(version); + checkNotNull(components); + return new AutoValue_CelSource_Extension(id, version, ImmutableList.copyOf(components)); + } + + @CheckReturnValue + static Extension create(String id, Version version, Component... components) { + return create(id, version, Arrays.asList(components)); + } + } } diff --git a/common/src/test/java/dev/cel/common/CelProtoAbstractSyntaxTreeTest.java b/common/src/test/java/dev/cel/common/CelProtoAbstractSyntaxTreeTest.java index 8ab329f6..1d6170f5 100644 --- a/common/src/test/java/dev/cel/common/CelProtoAbstractSyntaxTreeTest.java +++ b/common/src/test/java/dev/cel/common/CelProtoAbstractSyntaxTreeTest.java @@ -25,6 +25,9 @@ import dev.cel.expr.ParsedExpr; import dev.cel.expr.Reference; import dev.cel.expr.SourceInfo; +import dev.cel.expr.SourceInfo.Extension; +import dev.cel.expr.SourceInfo.Extension.Component; +import dev.cel.expr.SourceInfo.Extension.Version; import dev.cel.common.types.CelTypes; import java.util.Arrays; import org.junit.Test; @@ -57,6 +60,13 @@ public class CelProtoAbstractSyntaxTreeTest { .setLocation("test/location.cel") .putPositions(1L, 0) .addLineOffsets(4) + .addExtensions( + Extension.newBuilder() + .setId("extension_id") + .addAffectedComponents(Component.COMPONENT_PARSER) + .addAffectedComponents(Component.COMPONENT_TYPE_CHECKER) + .addAffectedComponents(Component.COMPONENT_RUNTIME) + .setVersion(Version.newBuilder().setMajor(5).setMinor(3))) .putMacroCalls( 2, Expr.newBuilder() diff --git a/common/src/test/java/dev/cel/common/CelSourceTest.java b/common/src/test/java/dev/cel/common/CelSourceTest.java index e7b2aaf0..e77db1da 100644 --- a/common/src/test/java/dev/cel/common/CelSourceTest.java +++ b/common/src/test/java/dev/cel/common/CelSourceTest.java @@ -19,6 +19,9 @@ import static org.junit.Assert.assertThrows; import com.google.common.truth.Truth8; +import dev.cel.common.CelSource.Extension; +import dev.cel.common.CelSource.Extension.Component; +import dev.cel.common.CelSource.Extension.Version; import dev.cel.common.internal.BasicCodePointArray; import dev.cel.common.internal.CodePointStream; import dev.cel.common.internal.Latin1CodePointArray; @@ -157,4 +160,25 @@ public void fromString_handlesMultiLineSupplemental() throws Exception { assertThat(charStream.LA(-1)).isEqualTo(IntStream.EOF); assertThat(source.getLineOffsets()).containsExactly(6, 13).inOrder(); } + + @Test + public void source_withExtension() { + CelSource celSource = + CelSource.newBuilder("") + .addAllExtensions( + Extension.create( + "extension_id", + Version.of(5, 3), + Component.COMPONENT_PARSER, + Component.COMPONENT_TYPE_CHECKER)) + .build(); + + Extension extension = celSource.getExtensions().get(0); + assertThat(extension.id()).isEqualTo("extension_id"); + assertThat(extension.version().major()).isEqualTo(5L); + assertThat(extension.version().minor()).isEqualTo(3L); + assertThat(extension.affectedComponents()) + .containsExactly(Component.COMPONENT_PARSER, Component.COMPONENT_TYPE_CHECKER); + assertThat(celSource.getExtensions()).hasSize(1); + } }