Skip to content

Commit

Permalink
Add serialization capability for tagged AST extensions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 603185360
  • Loading branch information
l46kok authored and copybara-github committed Jan 31, 2024
1 parent 9ae535c commit aa0fb8c
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 19 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,10 @@ http_archive(
# cel-spec api/expr canonical protos
http_archive(
name = "cel_spec",
sha256 = "6c2d9ec6dd5e2afbc41423dcaeca9fdd73edcd76554a897caee5b9a2b0e20491",
strip_prefix = "cel-spec-0.11.0",
sha256 = "3579c97b13548714f9059ef6f30c5264d439efef4b438e76e7180709efd93a6b",
strip_prefix = "cel-spec-0.14.0",
urls = [
"https://github.com/google/cel-spec/archive/refs/tags/v0.11.0.tar.gz",
"https://github.com/google/cel-spec/archive/refs/tags/v0.14.0.tar.gz",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,22 @@
package dev.cel.common;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

import dev.cel.expr.CheckedExpr;
import dev.cel.expr.Expr;
import dev.cel.expr.ParsedExpr;
import dev.cel.expr.SourceInfo;
import dev.cel.expr.SourceInfo.Extension;
import dev.cel.expr.SourceInfo.Extension.Component;
import dev.cel.expr.SourceInfo.Extension.Version;
import dev.cel.expr.Type;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CheckReturnValue;
import dev.cel.common.ast.CelExprConverter;
import dev.cel.common.types.CelTypes;
import java.util.Collection;
import java.util.Map.Entry;

/**
Expand All @@ -49,6 +55,9 @@ private CelProtoAbstractSyntaxTree(CheckedExpr checkedExpr) {
.addAllMacroCalls(
CelExprConverter.exprMacroCallsToCelExprMacroCalls(
checkedExpr.getSourceInfo().getMacroCallsMap()))
.addAllExtensions(
fromExprExtensionsToCelExtensions(
checkedExpr.getSourceInfo().getExtensionsList()))
.setDescription(checkedExpr.getSourceInfo().getLocation())
.build(),
checkedExpr.getReferenceMapMap().entrySet().stream()
Expand All @@ -69,6 +78,8 @@ private CelProtoAbstractSyntaxTree(CelAbstractSyntaxTree ast) {
SourceInfo.newBuilder()
.setLocation(ast.getSource().getDescription())
.addAllLineOffsets(ast.getSource().getLineOffsets())
.addAllExtensions(
fromCelExtensionsToExprExtensions(ast.getSource().getExtensions()))
.putAllMacroCalls(
ast.getSource().getMacroCalls().entrySet().stream()
.collect(
Expand Down Expand Up @@ -173,5 +184,69 @@ public ParsedExpr toParsedExpr() {
public Type getProtoResultType() {
return CelTypes.celTypeToType(ast.getResultType());
}

private static ImmutableList<Extension> fromCelExtensionsToExprExtensions(
Collection<CelSource.Extension> extensions) {
return extensions.stream()
.map(
celSourceExtension ->
Extension.newBuilder()
.setId(celSourceExtension.id())
.setVersion(
Version.newBuilder()
.setMajor(celSourceExtension.version().major())
.setMinor(celSourceExtension.version().minor()))
.addAllAffectedComponents(
celSourceExtension.affectedComponents().stream()
.map(
component -> {
switch (component) {
case COMPONENT_UNSPECIFIED:
return Component.COMPONENT_UNSPECIFIED;
case COMPONENT_PARSER:
return Component.COMPONENT_PARSER;
case COMPONENT_TYPE_CHECKER:
return Component.COMPONENT_TYPE_CHECKER;
case COMPONENT_RUNTIME:
return Component.COMPONENT_RUNTIME;
}
throw new AssertionError(
"Unexpected component kind: " + component);
})
.collect(toImmutableList()))
.build())
.collect(toImmutableList());
}

private static ImmutableList<CelSource.Extension> fromExprExtensionsToCelExtensions(
Collection<Extension> extensions) {
return extensions.stream()
.map(
exprExtension ->
CelSource.Extension.create(
exprExtension.getId(),
CelSource.Extension.Version.of(
exprExtension.getVersion().getMajor(),
exprExtension.getVersion().getMinor()),
exprExtension.getAffectedComponentsList().stream()
.map(
component -> {
switch (component) {
case COMPONENT_UNSPECIFIED:
return CelSource.Extension.Component.COMPONENT_UNSPECIFIED;
case COMPONENT_PARSER:
return CelSource.Extension.Component.COMPONENT_PARSER;
case COMPONENT_TYPE_CHECKER:
return CelSource.Extension.Component.COMPONENT_TYPE_CHECKER;
case COMPONENT_RUNTIME:
return CelSource.Extension.Component.COMPONENT_RUNTIME;
case UNRECOGNIZED:
// fall-through
}
throw new AssertionError("Unexpected component kind: " + component);
})
.collect(toImmutableList())))
.collect(toImmutableList());
}
}
// LINT.ThenChange(CelProtoV1Alpha1AbstractSyntaxTree.java)
116 changes: 100 additions & 16 deletions common/src/main/java/dev/cel/common/CelSource.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.auto.value.AutoValue;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -26,6 +27,7 @@
import dev.cel.common.ast.CelExpr;
import dev.cel.common.internal.CelCodePointArray;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -34,21 +36,22 @@
/** Represents the source content of an expression and related metadata. */
@Immutable
public final class CelSource {

private static final Splitter LINE_SPLITTER = Splitter.on('\n');

private final CelCodePointArray codePoints;
private final String description;
private final ImmutableList<Integer> lineOffsets;
private final ImmutableMap<Long, Integer> positions;
private final ImmutableMap<Long, CelExpr> macroCalls;
private final ImmutableList<Extension> extensions;

private CelSource(Builder builder) {
codePoints = checkNotNull(builder.codePoints);
description = checkNotNull(builder.description);
positions = checkNotNull(builder.positions.buildOrThrow());
lineOffsets = checkNotNull(ImmutableList.copyOf(builder.lineOffsets));
macroCalls = checkNotNull(ImmutableMap.copyOf(builder.macroCalls));
this.codePoints = checkNotNull(builder.codePoints);
this.description = checkNotNull(builder.description);
this.positions = checkNotNull(builder.positions.buildOrThrow());
this.lineOffsets = checkNotNull(ImmutableList.copyOf(builder.lineOffsets));
this.macroCalls = checkNotNull(ImmutableMap.copyOf(builder.macroCalls));
this.extensions = checkNotNull(builder.extensions.build());
}

public CelCodePointArray getContent() {
Expand Down Expand Up @@ -77,6 +80,10 @@ public ImmutableMap<Long, CelExpr> getMacroCalls() {
return macroCalls;
}

public ImmutableList<Extension> getExtensions() {
return extensions;
}

/** See {@link #getLocationOffset(int, int)}. */
public Optional<Integer> getLocationOffset(CelSourceLocation location) {
checkNotNull(location);
Expand Down Expand Up @@ -201,6 +208,7 @@ public static final class Builder {
private final List<Integer> lineOffsets;
private final ImmutableMap.Builder<Long, Integer> positions;
private final Map<Long, CelExpr> macroCalls;
private final ImmutableList.Builder<Extension> extensions;

private String description;

Expand All @@ -213,7 +221,8 @@ private Builder(CelCodePointArray codePoints, List<Integer> lineOffsets) {
this.lineOffsets = checkNotNull(lineOffsets);
this.positions = ImmutableMap.builder();
this.macroCalls = new HashMap<>();
description = "";
this.extensions = ImmutableList.builder();
this.description = "";
}

@CanIgnoreReturnValue
Expand All @@ -229,15 +238,6 @@ public Builder addLineOffsets(int lineOffset) {
return this;
}

@CanIgnoreReturnValue
public Builder addLineOffsets(int... lineOffsets) {
// Purposefully not using Arrays.asList to avoid int boxing/unboxing.
for (int index = 0; index != lineOffsets.length; index++) {
addLineOffsets(lineOffsets[index]);
}
return this;
}

@CanIgnoreReturnValue
public Builder addAllLineOffsets(Iterable<Integer> lineOffsets) {
for (int lineOffset : lineOffsets) {
Expand Down Expand Up @@ -277,6 +277,18 @@ public Builder clearMacroCall(long exprId) {
return this;
}

@CanIgnoreReturnValue
public Builder addAllExtensions(Iterable<? extends Extension> extensions) {
checkNotNull(extensions);
this.extensions.addAll(extensions);
return this;
}

@CanIgnoreReturnValue
public Builder addAllExtensions(Extension... extensions) {
return addAllExtensions(Arrays.asList(extensions));
}

/** See {@link #getLocationOffset(int, int)}. */
public Optional<Integer> getLocationOffset(CelSourceLocation location) {
checkNotNull(location);
Expand Down Expand Up @@ -333,4 +345,76 @@ private LineAndOffset(int line, int offset) {
int line;
int offset;
}

/**
* Tag for an extension that were used while parsing or type checking the source expression. For
* example, optimizations that require special runtime support may be specified. These are used to
* check feature support between components in separate implementations. This can be used to
* either skip redundant work or report an error if the extension is unsupported.
*/
@AutoValue
@Immutable
abstract static class Extension {

/** Identifier for the extension. Example: constant_folding */
abstract String id();

/**
* Version info. May be skipped if it isn't meaningful for the extension. (for example
* constant_folding might always be v0.0).
*/
abstract Version version();

/**
* If set, the listed components must understand the extension for the expression to evaluate
* correctly.
*/
abstract ImmutableList<Component> affectedComponents();

@AutoValue
@Immutable
abstract static class Version {

/**
* Major version changes indicate different required support level from the required
* components.
*/
abstract long major();

/**
* Minor version changes must not change the observed behavior from existing implementations,
* but may be provided informational.
*/
abstract long minor();

/** Create a new instance of Version with the provided major and minor values. */
static Version of(long major, long minor) {
return new AutoValue_CelSource_Extension_Version(major, minor);
}
}

/** CEL component specifier. */
enum Component {
/** Unspecified, default. */
COMPONENT_UNSPECIFIED,
/** Parser. Converts a CEL string to an AST. */
COMPONENT_PARSER,
/** Type checker. Checks that references in an AST are defined and types agree. */
COMPONENT_TYPE_CHECKER,
/** Runtime. Evaluates a parsed and optionally checked CEL AST against a context. */
COMPONENT_RUNTIME;
}

@CheckReturnValue
static Extension create(String id, Version version, Iterable<Component> components) {
checkNotNull(version);
checkNotNull(components);
return new AutoValue_CelSource_Extension(id, version, ImmutableList.copyOf(components));
}

@CheckReturnValue
static Extension create(String id, Version version, Component... components) {
return create(id, version, Arrays.asList(components));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import dev.cel.expr.ParsedExpr;
import dev.cel.expr.Reference;
import dev.cel.expr.SourceInfo;
import dev.cel.expr.SourceInfo.Extension;
import dev.cel.expr.SourceInfo.Extension.Component;
import dev.cel.expr.SourceInfo.Extension.Version;
import dev.cel.common.types.CelTypes;
import java.util.Arrays;
import org.junit.Test;
Expand Down Expand Up @@ -57,6 +60,13 @@ public class CelProtoAbstractSyntaxTreeTest {
.setLocation("test/location.cel")
.putPositions(1L, 0)
.addLineOffsets(4)
.addExtensions(
Extension.newBuilder()
.setId("extension_id")
.addAffectedComponents(Component.COMPONENT_PARSER)
.addAffectedComponents(Component.COMPONENT_TYPE_CHECKER)
.addAffectedComponents(Component.COMPONENT_RUNTIME)
.setVersion(Version.newBuilder().setMajor(5).setMinor(3)))
.putMacroCalls(
2,
Expr.newBuilder()
Expand Down
24 changes: 24 additions & 0 deletions common/src/test/java/dev/cel/common/CelSourceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import static org.junit.Assert.assertThrows;

import com.google.common.truth.Truth8;
import dev.cel.common.CelSource.Extension;
import dev.cel.common.CelSource.Extension.Component;
import dev.cel.common.CelSource.Extension.Version;
import dev.cel.common.internal.BasicCodePointArray;
import dev.cel.common.internal.CodePointStream;
import dev.cel.common.internal.Latin1CodePointArray;
Expand Down Expand Up @@ -157,4 +160,25 @@ public void fromString_handlesMultiLineSupplemental() throws Exception {
assertThat(charStream.LA(-1)).isEqualTo(IntStream.EOF);
assertThat(source.getLineOffsets()).containsExactly(6, 13).inOrder();
}

@Test
public void source_withExtension() {
CelSource celSource =
CelSource.newBuilder("")
.addAllExtensions(
Extension.create(
"extension_id",
Version.of(5, 3),
Component.COMPONENT_PARSER,
Component.COMPONENT_TYPE_CHECKER))
.build();

Extension extension = celSource.getExtensions().get(0);
assertThat(extension.id()).isEqualTo("extension_id");
assertThat(extension.version().major()).isEqualTo(5L);
assertThat(extension.version().minor()).isEqualTo(3L);
assertThat(extension.affectedComponents())
.containsExactly(Component.COMPONENT_PARSER, Component.COMPONENT_TYPE_CHECKER);
assertThat(celSource.getExtensions()).hasSize(1);
}
}

0 comments on commit aa0fb8c

Please sign in to comment.