diff --git a/README.md b/README.md index dab60e79..322e11d0 100644 --- a/README.md +++ b/README.md @@ -69,10 +69,10 @@ if len(errors.GetErrors()) != 0 { // package (empty string): typeProvider := types.NewProvider() env := checker.NewStandardEnv(packages.DefaultPackage, typeProvider, errors) -env.Add(decls.NewIdent("a", decls.Bool), - decls.NewIdent("b", decls.Bool), - decls.NewIdent("c", decls.NewListType(decls.Int))) -c := checker.Check(p, env, "") +env.Add(decls.NewIdent("a", decls.Bool, nil), + decls.NewIdent("b", decls.Bool, nil), + decls.NewIdent("c", decls.NewListType(decls.Int), nil)) +c := checker.Check(p, env, "") if len(errors.GetErrors()) != 0 { return nil, fmt.Error(errors.ToDisplayString())) } @@ -106,4 +106,4 @@ Disclaimer: This is not an official Google product. [3]: https://github.com/google/cel-cpp [4]: https://github.com/google/cel-go/issues [5]: https://bazel.build -[6]: https://godoc.org/github.com/google/cel-go \ No newline at end of file +[6]: https://godoc.org/github.com/google/cel-go diff --git a/WORKSPACE b/WORKSPACE index 0f432a08..6eec4d87 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -18,6 +18,50 @@ go_repository( git_repository( name = "com_google_cel_spec", - commit = "3769a0b59441e6a2ebe747154dd1a4c85ce65ae0", + commit = "3c25c4d4ffb504e2c24c1d84196c78ba3ac9e612", remote = "https://github.com/google/cel-spec.git", ) + +new_http_archive( + name = "com_google_googleapis", + url = "https://github.com/googleapis/googleapis/archive/common-protos-1_3_1.zip", + strip_prefix = "googleapis-common-protos-1_3_1/", + build_file_content = """ +load('@io_bazel_rules_go//proto:def.bzl', 'go_proto_library') + +proto_library( + name = 'rpc_status', + srcs = ['google/rpc/status.proto'], + deps = [ + '@com_google_protobuf//:any_proto', + '@com_google_protobuf//:empty_proto' + ], + visibility = ['//visibility:public'], +) + +go_proto_library( + name = 'rpc_status_go_proto', + # TODO: Switch to the correct import path when bazel rules fixed. + #importpath = 'google.golang.org/genproto/googleapis/rpc/status', + importpath = 'github.com/googleapis/googleapis/google/rpc', + proto = ':rpc_status', + visibility = ['//visibility:public'], +) +""" +) + +git_repository( + name = "org_pubref_rules_protobuf", + remote = "https://github.com/pubref/rules_protobuf", + tag = "v0.8.2", +) + +load("@org_pubref_rules_protobuf//go:rules.bzl", "go_proto_repositories") +go_proto_repositories() + +go_repository( + name = "org_golang_google_grpc", + importpath = "google.golang.org/grpc", + tag = "v1.11.3", + remote = "https://github.com/grpc/grpc-go.git", +) diff --git a/checker/BUILD.bazel b/checker/BUILD.bazel index 7b5cfae7..31c8bfa8 100644 --- a/checker/BUILD.bazel +++ b/checker/BUILD.bazel @@ -20,10 +20,10 @@ go_library( "//common/packages:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", - "//io:checked_proto", - "//io:syntax_proto", "//parser:go_default_library", "@com_github_golang_protobuf//proto:go_default_library", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@io_bazel_rules_go//proto/wkt:empty_go_proto", "@io_bazel_rules_go//proto/wkt:struct_go_proto", ], diff --git a/checker/decls/BUILD.bazel b/checker/decls/BUILD.bazel index 006cc763..c515e4cb 100644 --- a/checker/decls/BUILD.bazel +++ b/checker/decls/BUILD.bazel @@ -7,8 +7,8 @@ go_library( "scopes.go", ], deps = [ - "//io:checked_proto", - "//io:syntax_proto", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@io_bazel_rules_go//proto/wkt:empty_go_proto", "@io_bazel_rules_go//proto/wkt:struct_go_proto", ], diff --git a/common/BUILD.bazel b/common/BUILD.bazel index 608b0459..583a3106 100644 --- a/common/BUILD.bazel +++ b/common/BUILD.bazel @@ -9,6 +9,9 @@ go_library( "source.go", ], importpath = "github.com/google/cel-go/common", + deps = [ + "@com_google_cel_spec//proto/v1:syntax_go_proto", + ], visibility = ["//visibility:public"], ) diff --git a/common/debug/BUILD.bazel b/common/debug/BUILD.bazel index 5670603b..547c306f 100644 --- a/common/debug/BUILD.bazel +++ b/common/debug/BUILD.bazel @@ -9,6 +9,6 @@ go_library( visibility = ["//visibility:public"], deps = [ "//common:go_default_library", - "//io:syntax_proto", + "@com_google_cel_spec//proto/v1:syntax_go_proto", ], ) diff --git a/common/source.go b/common/source.go index 3bd56c47..c39544e6 100644 --- a/common/source.go +++ b/common/source.go @@ -16,6 +16,8 @@ package common import ( "strings" + + "github.com/google/cel-spec/proto/v1/syntax" ) // Source interface for filter source contents. @@ -35,26 +37,37 @@ type Source interface { // The raw character offset at which the a location exists given the // location line and column. // Returns the line offset and whether the location was found. - CharacterOffset(location Location) (int32, bool) + LocationOffset(location Location) (int32, bool) - // LocationFromOffset translates a raw character offset to a Location, or + // OffsetLocation translates a character offset to a Location, or // false if the conversion was not feasible. - LocationFromOffset(offset int32) (Location, bool) + OffsetLocation(offset int32) (Location, bool) - // Return a line of content from the source and whether the line was found. + // Return a line of content and whether the line was found. Snippet(line int) (string, bool) -} -// Ensure the StringSource implements the Source interface. -var _ Source = &StringSource{} + // IdOffset returns the raw character offset of an expression within + // the source, or false if the expression cannot be found. + IdOffset(exprId int64) (int32, bool) -// StringSource type implementation of the Source interface. -type StringSource struct { + // IdLocation returns a Location for the given expression id, + // or false if one cannot be found. It behaves as the obvious + // composition of IdOffset() and OffsetLocation(). + IdLocation(exprId int64) (Location, bool) +} + +// The sourceImpl type implementation of the Source interface. +type sourceImpl struct { contents string description string lineOffsets []int32 + idOffsets map[int64]int32 } +// TODO(jimlarson) "Character offsets" should index the code points +// within the UTF-8 encoded string. It currently indexes bytes. +// Can be accomplished by using rune[] instead of string for contents. + // Create a new Source given the string contents and description. func NewStringSource(contents string, description string) Source { // Compute line offsets up front as they are referred to frequently. @@ -65,49 +78,78 @@ func NewStringSource(contents string, description string) Source { offset = offset + int32(len(line)) + 1 offsets[int32(i)] = offset } - return &StringSource{ + return &sourceImpl{ contents: contents, description: description, lineOffsets: offsets, + idOffsets: map[int64]int32{}, } } -func (s *StringSource) Content() string { +func NewInfoSource(info *syntax.SourceInfo) Source { + return &sourceImpl{ + contents: "", + description: info.Location, + lineOffsets: info.LineOffsets, + idOffsets: info.Positions, + } +} + +func (s *sourceImpl) Content() string { return s.contents } -func (s *StringSource) Description() string { +func (s *sourceImpl) Description() string { return s.description } -func (s *StringSource) LineOffsets() []int32 { +func (s *sourceImpl) LineOffsets() []int32 { return s.lineOffsets } -func (s *StringSource) CharacterOffset(location Location) (int32, bool) { +func (s *sourceImpl) LocationOffset(location Location) (int32, bool) { if lineOffset, found := s.findLineOffset(location.Line()); found { return lineOffset + int32(location.Column()), true } return -1, false } -func (s *StringSource) LocationFromOffset(offset int32) (Location, bool) { +func (s *sourceImpl) OffsetLocation(offset int32) (Location, bool) { line, lineOffset := s.findLine(offset) return NewLocation(int(line), int(offset-lineOffset)), true } -func (s *StringSource) Snippet(line int) (string, bool) { - if charStart, found := s.findLineOffset(line); found { - charEnd, found := s.findLineOffset(line + 1) - if found { - return s.contents[charStart : charEnd-1], true +func (s *sourceImpl) Snippet(line int) (string, bool) { + charStart, found := s.findLineOffset(line) + if !found || len(s.contents) == 0 { + return "", false + } + charEnd, found := s.findLineOffset(line + 1) + if found { + return s.contents[charStart : charEnd-1], true + } + return s.contents[charStart:], true +} + +func (s *sourceImpl) IdOffset(exprId int64) (int32, bool) { + if offset, found := s.idOffsets[exprId]; found { + return offset, true + } + return -1, false +} + +func (s *sourceImpl) IdLocation(exprId int64) (Location, bool) { + if offset, found := s.IdOffset(exprId); found { + if location, found := s.OffsetLocation(offset); found { + return location, true } - return s.contents[charStart:], true } - return "", false + return NewLocation(1, 0), false } -func (s *StringSource) findLineOffset(line int) (int32, bool) { +// findLineOffset returns the offset where the (1-indexed) line begins, +// or false if line doesn't exist. +func (s *sourceImpl) findLineOffset(line int) (int32, bool) { if line == 1 { return 0, true } else if line > 1 && line <= int(len(s.lineOffsets)) { @@ -117,7 +159,11 @@ func (s *StringSource) findLineOffset(line int) (int32, bool) { return -1, false } -func (s *StringSource) findLine(characterOffset int32) (int32, int32) { +// findLine finds the line that contains the given character offset and +// returns the line number and offset of the beginning of that line. +// Note that the last line is treated as if it contains all offsets +// beyond the end of the actual source. +func (s *sourceImpl) findLine(characterOffset int32) (int32, int32) { var line int32 = 1 for _, lineOffset := range s.lineOffsets { if lineOffset > characterOffset { diff --git a/common/source_test.go b/common/source_test.go index 3bbc9058..161fac3d 100644 --- a/common/source_test.go +++ b/common/source_test.go @@ -54,9 +54,9 @@ func TestStringSource_Description(t *testing.T) { } } -// TestStringSource_CharacterOffset make sure that the offsets accurately reflect +// TestStringSource_LocationOffset make sure that the offsets accurately reflect // the location of a character in source. -func TestStringSource_CharacterOffset(t *testing.T) { +func TestStringSource_LocationOffset(t *testing.T) { contents := "c.d &&\n\t b.c.arg(10) &&\n\t test(10)" source := NewStringSource(contents, "offset-test") expectedLineOffsets := []int32{7, 24, 35} @@ -73,14 +73,14 @@ func TestStringSource_CharacterOffset(t *testing.T) { } // Ensure that selecting a set of characters across multiple lines works as // expected. - charStart, _ := source.CharacterOffset(NewLocation(1, 2)) - charEnd, _ := source.CharacterOffset(NewLocation(3, 2)) + charStart, _ := source.LocationOffset(NewLocation(1, 2)) + charEnd, _ := source.LocationOffset(NewLocation(3, 2)) if "d &&\n\t b.c.arg(10) &&\n\t " != string(contents[charStart:charEnd]) { t.Errorf(unexpectedValue, t.Name(), string(contents[charStart:charEnd]), "d &&\n\t b.c.arg(10) &&\n\t ") } - if _, found := source.CharacterOffset(NewLocation(4, 0)); found { + if _, found := source.LocationOffset(NewLocation(4, 0)); found { t.Error("Character offset was out of range of source, but still found.") } } diff --git a/common/types/BUILD.bazel b/common/types/BUILD.bazel index dec16826..33d596cb 100644 --- a/common/types/BUILD.bazel +++ b/common/types/BUILD.bazel @@ -35,7 +35,7 @@ go_library( "//common/types/ref:go_default_library", "//common/types/pb:go_default_library", "//common/types/traits:go_default_library", - "//io:checked_proto", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_golang_protobuf//ptypes:go_default_library", "@io_bazel_rules_go//proto/wkt:any_go_proto", @@ -68,11 +68,11 @@ go_test( embed = [":go_default_library"], deps = [ "//common/types/ref:go_default_library", - "//io:syntax_proto", "//test:go_default_library", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@com_github_golang_protobuf//jsonpb:go_default_library", "@io_bazel_rules_go//proto/wkt:any_go_proto", "@io_bazel_rules_go//proto/wkt:duration_go_proto", "@io_bazel_rules_go//proto/wkt:timestamp_go_proto", ], -) \ No newline at end of file +) diff --git a/common/types/list.go b/common/types/list.go index 9c03e0cb..71c4037a 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -16,10 +16,11 @@ package types import ( "fmt" + "reflect" + "github.com/golang/protobuf/ptypes/struct" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" - "reflect" ) var ( @@ -33,6 +34,8 @@ var ( ) // NewDynamicList returns a traits.Lister with heterogenous elements. +// value should be an array of "native" types, i.e. any type that +// NativeToValue() can convert to a ref.Value. func NewDynamicList(value interface{}) traits.Lister { return &baseList{value, reflect.ValueOf(value)} } @@ -44,7 +47,15 @@ func NewStringList(elems []string) traits.Lister { elems: elems} } +// NewValueList returns a traits.Lister with ref.Value elements. +func NewValueList(elems []ref.Value) traits.Lister { + return &valueList{ + baseList: NewDynamicList(elems).(*baseList), + elems: elems} +} + // baseList points to a list containing elements of any type. +// value is an array of native values, and refValue is its reflection object. type baseList struct { value interface{} refValue reflect.Value @@ -364,6 +375,48 @@ func (l *stringList) Size() ref.Value { return Int(len(l.elems)) } +// valueList is a specialization of traits.Lister for ref.Value. +type valueList struct { + *baseList + elems []ref.Value +} + +func (l *valueList) Add(other ref.Value) ref.Value { + if other.Type() != ListType { + return NewErr("no such overload") + } + return &concatList{ + prevList: l, + nextList: other.(traits.Lister)} +} + +func (l *valueList) ConvertToNative(typeDesc reflect.Type) (interface{}, error) { + natives := make([]interface{}, len(l.elems)) + for _, v := range l.elems { + if n, e := v.ConvertToNative(typeDesc); e != nil { + return nil, e + } else { + natives = append(natives, n) + } + } + return natives, nil +} + +func (l *valueList) Get(index ref.Value) ref.Value { + if index.Type() != IntType { + return NewErr("unsupported index type '%s' in list", index.Type()) + } + i := index.(Int) + if i < 0 || i >= l.Size().(Int) { + return NewErr("index '%d' out of range in list size '%d'", i, l.Size()) + } + return l.elems[i] +} + +func (l *valueList) Size() ref.Value { + return Int(len(l.elems)) +} + type listIterator struct { *baseIterator listValue traits.Lister diff --git a/common/types/list_test.go b/common/types/list_test.go index 2a5e8321..081557b0 100644 --- a/common/types/list_test.go +++ b/common/types/list_test.go @@ -18,6 +18,7 @@ import ( "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/duration" + "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "reflect" "testing" @@ -80,7 +81,23 @@ func TestBaseList_Equal(t *testing.T) { } func TestBaseList_Get(t *testing.T) { - list := NewDynamicList([]int32{1, 2, 3}) + validateList123(t, NewDynamicList([]int32{1, 2, 3}).(traits.Lister)) +} + +func TestValueList_Get(t *testing.T) { + validateList123(t, NewValueList([]ref.Value{Int(1), Int(2), Int(3)})) +} + +func TestBaseList_Iterator(t *testing.T) { + validateIterator123(t, NewDynamicList([]int32{1, 2, 3}).(traits.Lister)) +} + +func TestValueListValue_Iterator(t *testing.T) { + validateIterator123(t, NewValueList([]ref.Value{Int(1), Int(2), Int(3)})) +} + +func validateList123(t *testing.T, list traits.Lister) { + t.Helper() if getElem(t, list, 0) != Int(1) || getElem(t, list, 1) != Int(2) || getElem(t, list, 2) != Int(3) { @@ -97,8 +114,8 @@ func TestBaseList_Get(t *testing.T) { } } -func TestBaseList_Iterator(t *testing.T) { - list := NewDynamicList([]int32{1, 2, 3}) +func validateIterator123(t *testing.T, list traits.Lister) { + t.Helper() it := list.Iterator() var i = int64(0) for ; it.HasNext() == True; i++ { diff --git a/common/types/pb/BUILD.bazel b/common/types/pb/BUILD.bazel index 3c428a63..f26cb694 100644 --- a/common/types/pb/BUILD.bazel +++ b/common/types/pb/BUILD.bazel @@ -13,7 +13,7 @@ go_library( ], importpath = "github.com/google/cel-go/common/types/pb", deps = [ - "//io:checked_proto", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", "@com_github_golang_protobuf//descriptor:go_default_library", "@com_github_golang_protobuf//proto:go_default_library", "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", diff --git a/common/types/ref/BUILD.bazel b/common/types/ref/BUILD.bazel index 1b7154aa..c70beb58 100644 --- a/common/types/ref/BUILD.bazel +++ b/common/types/ref/BUILD.bazel @@ -10,6 +10,6 @@ go_library( ], importpath = "github.com/google/cel-go/common/types/ref", deps = [ - "//io:checked_proto", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", ], ) diff --git a/interpreter/BUILD.bazel b/interpreter/BUILD.bazel index cf6b67d9..b6525bd9 100644 --- a/interpreter/BUILD.bazel +++ b/interpreter/BUILD.bazel @@ -13,7 +13,7 @@ go_library( "interpreter.go", "metadata.go", "program.go", - "prune.go", + "prune.go", ], importpath = "github.com/google/cel-go/interpreter", deps = [ @@ -25,8 +25,8 @@ go_library( "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", "//interpreter/functions:go_default_library", - "//io:checked_proto", - "//io:syntax_proto", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@com_github_golang_protobuf//proto:go_default_library", "@io_bazel_rules_go//proto/wkt:duration_go_proto", "@io_bazel_rules_go//proto/wkt:struct_go_proto", @@ -42,7 +42,7 @@ go_test( "evalstate_test.go", "interpreter_test.go", "program_test.go", - "prune_test.go", + "prune_test.go", ], embed = [ ":go_default_library", @@ -54,9 +54,9 @@ go_test( "//common/packages:go_default_library", "//common/types:go_default_library", "//interpreter/functions:go_default_library", - "//io:syntax_proto", "//parser:go_default_library", "//test:go_default_library", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@com_github_golang_protobuf//proto:go_default_library", "@io_bazel_rules_go//proto/wkt:duration_go_proto", "@io_bazel_rules_go//proto/wkt:struct_go_proto", diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index f2554392..1c7484c4 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -31,7 +31,7 @@ type Interpreter interface { NewInterpretable(program Program) Interpretable } -// Interpretable can accept a given Activation and produce a result along with +// Interpretable can accept a given Activation and produce a value along with // an accompanying EvalState which can be used to inspect whether additional // data might be necessary to complete the evaluation. type Interpretable interface { diff --git a/interpreter/metadata.go b/interpreter/metadata.go index 7c71a901..b4f1f198 100644 --- a/interpreter/metadata.go +++ b/interpreter/metadata.go @@ -19,12 +19,14 @@ import ( ) // Metadata interface for accessing position information about expressions. +// TODO(jimlarson) Replace with common.Source. type Metadata interface { - // CharacterOffset returns the raw character offset of an expression within + // IdOffset returns raw character offset of an expression within // Source, or false if the expression cannot be found. - CharacterOffset(exprId int64) (int32, bool) + IdOffset(exprId int64) (int32, bool) - // Location returns a common.Location for the given expression id, or false - // if one cannot be found. - Location(exprId int64) (common.Location, bool) + // IdLocation returns a common.Location for the given expression id, + // or false if one cannot be found. It behaves as the obvious + // composition of IdOffset() and OffsetLocation(). + IdLocation(exprId int64) (common.Location, bool) } diff --git a/interpreter/program.go b/interpreter/program.go index 7c0c2c29..f8940878 100644 --- a/interpreter/program.go +++ b/interpreter/program.go @@ -163,8 +163,8 @@ func newExprMetadata(info *expr.SourceInfo) Metadata { return &exprMetadata{info: info} } -func (m *exprMetadata) Location(exprId int64) (common.Location, bool) { - if exprOffset, found := m.CharacterOffset(exprId); found { +func (m *exprMetadata) IdLocation(exprId int64) (common.Location, bool) { + if exprOffset, found := m.IdOffset(exprId); found { var index = 0 var lineIndex = 0 var lineOffset int32 = 0 @@ -181,7 +181,7 @@ func (m *exprMetadata) Location(exprId int64) (common.Location, bool) { return nil, false } -func (m *exprMetadata) CharacterOffset(exprId int64) (int32, bool) { +func (m *exprMetadata) IdOffset(exprId int64) (int32, bool) { position, found := m.info.Positions[exprId] return position, found } diff --git a/interpreter/program_test.go b/interpreter/program_test.go index a1eab097..ae275666 100644 --- a/interpreter/program_test.go +++ b/interpreter/program_test.go @@ -25,7 +25,7 @@ func TestNewProgram_Empty(t *testing.T) { program := NewProgram( test.Empty.Expr, test.Empty.Info(t.Name())) - if loc, found := program.Metadata().Location(0); found { + if loc, found := program.Metadata().IdLocation(0); found { t.Errorf("Unexpected location found: %v", loc) } state := NewEvalState(program.MaxInstructionId() + 1) @@ -39,7 +39,7 @@ func TestNewProgram_LogicalAnd(t *testing.T) { program := NewProgram( test.LogicalAnd.Expr, test.LogicalAnd.Info(t.Name())) - if loc, found := program.Metadata().Location(1); found { + if loc, found := program.Metadata().IdLocation(1); found { t.Errorf("Unexpected location found: %v", loc) } state := NewEvalState(program.MaxInstructionId() + 1) @@ -54,7 +54,7 @@ func TestNewProgram_Conditional(t *testing.T) { program := NewProgram( test.Conditional.Expr, test.Conditional.Info(t.Name())) - if loc, found := program.Metadata().Location(1); found { + if loc, found := program.Metadata().IdLocation(1); found { t.Errorf("Unexpected location found: %v", loc) } state := NewEvalState(program.MaxInstructionId() + 1) @@ -85,7 +85,7 @@ func TestNewProgram_Comprehension(t *testing.T) { program := NewProgram( test.Exists.Expr, test.Exists.Info(t.Name())) - if loc, found := program.Metadata().Location(1); !found { + if loc, found := program.Metadata().IdLocation(1); !found { t.Errorf("Unexpected location found: %v", loc) } state := NewEvalState(program.MaxInstructionId() + 1) @@ -100,7 +100,7 @@ func TestNewProgram_DynMap(t *testing.T) { program := NewProgram( test.DynMap.Expr, test.DynMap.Info(t.Name())) - if loc, found := program.Metadata().Location(1); found { + if loc, found := program.Metadata().IdLocation(1); found { t.Errorf("Unexpected location found: %v", loc) } state := NewEvalState(program.MaxInstructionId() + 1) diff --git a/io/BUILD.bazel b/io/BUILD.bazel deleted file mode 100644 index 2eea0af7..00000000 --- a/io/BUILD.bazel +++ /dev/null @@ -1,28 +0,0 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") - -package( - default_visibility = [ - "//checker:__subpackages__", - "//common:__subpackages__", - "//interpreter:__subpackages__", - "//parser:__subpackages__", - "//semantics:__subpackages__", - "//test:__subpackages__", - ], -) - -go_proto_library( - name = "checked_proto", - importpath = "github.com/google/cel-spec/proto/checked/v1/checked", - proto = "@com_google_cel_spec//proto/checked/v1:checked_protos", - deps = [ - ":syntax_proto" - ], -) - -go_proto_library( - name = "syntax_proto", - importpath = "github.com/google/cel-spec/proto/v1/syntax", - proto = "@com_google_cel_spec//proto/v1:syntax_protos", -) \ No newline at end of file diff --git a/parser/BUILD.bazel b/parser/BUILD.bazel index a1184fc9..0a42c833 100644 --- a/parser/BUILD.bazel +++ b/parser/BUILD.bazel @@ -11,10 +11,10 @@ go_library( ], importpath = "github.com/google/cel-go/parser", deps = [ - "//io:syntax_proto", "//common:go_default_library", "//common/operators:go_default_library", "//parser/gen:go_default_library", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@com_github_antlr//runtime/Go/antlr:go_default_library", "@io_bazel_rules_go//proto/wkt:struct_go_proto", ], diff --git a/parser/helper.go b/parser/helper.go index bd30f903..4e893a50 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -202,14 +202,14 @@ func (p *parserHelper) id(ctx interface{}) int64 { } location := common.NewLocation(token.GetLine(), token.GetColumn()) id := p.nextId - p.positions[id], _ = p.source.CharacterOffset(location) + p.positions[id], _ = p.source.LocationOffset(location) p.nextId++ return id } func (p *parserHelper) getLocation(id int64) common.Location { characterOffset := p.positions[id] - location, _ := p.source.LocationFromOffset(characterOffset) + location, _ := p.source.OffsetLocation(characterOffset) return location } diff --git a/parser/macro.go b/parser/macro.go index 11d20bec..6651c6e9 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -141,7 +141,7 @@ func makeQuantifier(kind quantifierKind, p *parserHelper, ctx interface{}, targe v, found := extractIdent(args[0]) if !found { offset := p.positions[args[0].Id] - location, _ := p.source.LocationFromOffset(offset) + location, _ := p.source.OffsetLocation(offset) return p.reportError(location, "argument must be a simple name") } accuIdent := func() *expr.Expr { diff --git a/server/BUILD.bazel b/server/BUILD.bazel new file mode 100644 index 00000000..3c6f338a --- /dev/null +++ b/server/BUILD.bazel @@ -0,0 +1,59 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test") +load("@org_pubref_rules_protobuf//go:rules.bzl", "GRPC_COMPILE_DEPS") + +go_library( + name = "go_default_library", + srcs = [ + "server.go", + ], + importpath = "github.com/google/cel-go/server", + deps = [ + "//checker:go_default_library", + "//common:go_default_library", + "//common/packages:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", + "//common/types/traits:go_default_library", + "//interpreter:go_default_library", + "//parser:go_default_library", + "@com_google_cel_spec//proto/v1:cel_service_go_proto", + "@com_google_cel_spec//proto/v1:eval_go_proto", + "@com_google_cel_spec//proto/v1:value_go_proto", + "@com_google_googleapis//:rpc_status_go_proto", + "@com_github_golang_protobuf//ptypes:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", + ] + GRPC_COMPILE_DEPS, + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["server_test.go"], + data = [":server_bin"], + deps = [ + ":go_default_library", + "//checker/decls:go_default_library", + "//common/operators:go_default_library", + "//test:go_default_library", + "@com_google_cel_spec//proto/checked/v1:checked_go_proto", + "@com_google_cel_spec//proto/v1:cel_service_go_proto", + "@com_google_cel_spec//proto/v1:eval_go_proto", + "@com_google_cel_spec//proto/v1:syntax_go_proto", + "@com_google_cel_spec//proto/v1:value_go_proto", + "@com_google_googleapis//:rpc_status_go_proto", + ] + GRPC_COMPILE_DEPS, + visibility = ["//visibility:public"], +) + +go_binary( + name = "server_bin", + srcs = ["main.go"], + deps = [ + "//server:go_default_library", + "@com_google_cel_spec//proto/v1:cel_service_go_proto", + "@org_golang_google_grpc//reflection:go_default_library", + ] + GRPC_COMPILE_DEPS, + out = "cel_server", + visibility = ["//visibility:public"], +) diff --git a/server/main.go b/server/main.go new file mode 100644 index 00000000..ea7da4ad --- /dev/null +++ b/server/main.go @@ -0,0 +1,38 @@ +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/google/cel-go/server" + "github.com/google/cel-spec/proto/v1/cel_service" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +func main() { + log.Println("Server opening listening port") + lis, err := net.Listen("tcp", "127.0.0.1:") + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + log.Println("Server opened port ", lis.Addr()) + + // Must print to stdout, so the client can find the port. + // So, no, this must be 'fmt', not 'log'. + fmt.Printf("Listening on %v\n", lis.Addr()) + os.Stdout.Sync() + log.Println("Server wrote address") + + log.Println("Server registering service on port") + s := grpc.NewServer() + cel_service.RegisterCelServiceServer(s, &server.CelServer{}) + log.Println("Server calling Register") + reflection.Register(s) + log.Println("Server calling Serve") + if err := s.Serve(lis); err != nil { + log.Fatalf("failed to serve: %v", err) + } +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 00000000..e6d4051d --- /dev/null +++ b/server/server.go @@ -0,0 +1,367 @@ +package server + +import ( + "context" + "fmt" + + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/packages" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" + "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/parser" + cspb "github.com/google/cel-spec/proto/v1/cel_service" + "github.com/google/cel-spec/proto/v1/eval" + "github.com/google/cel-spec/proto/v1/value" + "github.com/googleapis/googleapis/google/rpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type CelServer struct{} + +func (s *CelServer) Parse(ctx context.Context, in *cspb.ParseRequest) (*cspb.ParseResponse, error) { + if in.CelSource == "" { + st := status.New(codes.InvalidArgument, "No source code.") + return nil, st.Err() + } + // NOTE: syntax_version isn't currently used + src := common.NewStringSource(in.CelSource, in.SourceLocation) + var macs parser.Macros + if in.DisableMacros { + macs = parser.NoMacros + } else { + macs = parser.AllMacros + } + expr, errs := parser.Parse(src, macs) + resp := cspb.ParseResponse{} + if len(errs.GetErrors()) == 0 { + // Success + resp.ParsedExpr = expr + } else { + // Failure + appendErrors(errs, &resp.Issues) + } + return &resp, nil +} + +func (s *CelServer) Check(ctx context.Context, in *cspb.CheckRequest) (*cspb.CheckResponse, error) { + if in.ParsedExpr == nil { + st := status.New(codes.InvalidArgument, "No parsed expression.") + return nil, st.Err() + } + if in.ParsedExpr.SourceInfo == nil { + st := status.New(codes.InvalidArgument, "No source info.") + return nil, st.Err() + } + pkg := packages.NewPackage(in.Container) + typeProvider := types.NewProvider() + errs := common.NewErrors(common.NewInfoSource(in.ParsedExpr.SourceInfo)) + var env *checker.Env + if in.NoStdEnv { + env = checker.NewEnv(pkg, typeProvider, errs) + } else { + env = checker.NewStandardEnv(pkg, typeProvider, errs) + } + env.Add(in.TypeEnv...) + c := checker.Check(in.ParsedExpr, env) + resp := cspb.CheckResponse{} + if len(errs.GetErrors()) == 0 { + // Success + resp.CheckedExpr = c + } else { + // Failure + appendErrors(errs, &resp.Issues) + } + return &resp, nil +} + +func (s *CelServer) Eval(ctx context.Context, in *cspb.EvalRequest) (*cspb.EvalResponse, error) { + pkg := packages.NewPackage(in.Container) + typeProvider := types.NewProvider() + i := interpreter.NewStandardIntepreter(pkg, typeProvider) + var prog interpreter.Program + switch in.ExprKind.(type) { + case *cspb.EvalRequest_ParsedExpr: + parsed := in.GetParsedExpr() + prog = interpreter.NewProgram(parsed.Expr, parsed.SourceInfo) + case *cspb.EvalRequest_CheckedExpr: + prog = interpreter.NewCheckedProgram(in.GetCheckedExpr()) + default: + st := status.New(codes.InvalidArgument, "No expression.") + return nil, st.Err() + } + eval := i.NewInterpretable(prog) + args := make(map[string]interface{}) + for name, exprValue := range in.Bindings { + refVal, err := ExprValueToRefValue(exprValue) + if err != nil { + return nil, fmt.Errorf("can't convert binding %s: %s", name, err) + } + args[name] = refVal + } + // NOTE: the EvalState is currently discarded + result, _ := eval.Eval(interpreter.NewActivation(args)) + resultExprVal, err := RefValueToExprValue(result) + if err != nil { + return nil, fmt.Errorf("con't convert result: %s", err) + } + return &cspb.EvalResponse{Result: resultExprVal}, nil +} + +// appendErrors converts the errors from errs to Status messages +// and appends them to the list of issues. +func appendErrors(errs *common.Errors, issues *[]*rpc.Status) { + for _, e := range errs.GetErrors() { + status := ErrToStatus(e, cspb.StatusDetails_ERROR) + *issues = append(*issues, status) + } +} + +// ErrToStatus converts an Error to a Status message with the given severity. +func ErrToStatus(e common.Error, severity cspb.StatusDetails_Severity) *rpc.Status { + detail := cspb.StatusDetails{ + Severity: severity, + Line: int32(e.Location.Line()), + Column: int32(e.Location.Column()), + } + // TODO: simply use the following when we unify app-level + // and gRPC-level Status messages. + // return status.New(codes.InvalidArgument, e.message).WithDetails(detail).Proto() + s := rpc.Status{ + Code: int32(codes.InvalidArgument), + Message: e.Message, + } + any, err := ptypes.MarshalAny(&detail) + if err == nil { + s.Details = append(s.Details, any) + } + return &s +} + +// TODO(jimlarson): The following conversion code should be moved to +// common/types/provider.go and consolidated/refactored as appropriate. +// In particular, make judicious use of types.NativeToValue(). + +func RefValueToExprValue(res ref.Value) (*eval.ExprValue, error) { + if types.IsError(res) { + return &eval.ExprValue{ + Kind: &eval.ExprValue_Error{}}, nil + } + if types.IsUnknown(res) { + return &eval.ExprValue{ + Kind: &eval.ExprValue_Unknown{}}, nil + } + v, err := RefValueToValue(res) + if err != nil { + return nil, err + } + return &eval.ExprValue{ + Kind: &eval.ExprValue_Value{Value: v}}, nil +} + +var ( + typeNameToBasicType = map[string]value.TypeValue_BasicType{ + "bool": value.TypeValue_BOOL_TYPE, + "bytes": value.TypeValue_BYTES_TYPE, + "double": value.TypeValue_DOUBLE_TYPE, + "null_type": value.TypeValue_NULL_TYPE, + "int": value.TypeValue_INT_TYPE, + "list": value.TypeValue_LIST_TYPE, + "map": value.TypeValue_MAP_TYPE, + "string": value.TypeValue_STRING_TYPE, + "type": value.TypeValue_TYPE_TYPE, + "uint": value.TypeValue_UINT_TYPE, + } + basicTypeToTypeValue = map[value.TypeValue_BasicType]*types.TypeValue{ + value.TypeValue_NULL_TYPE: types.NullType, + value.TypeValue_BOOL_TYPE: types.BoolType, + value.TypeValue_INT_TYPE: types.IntType, + value.TypeValue_UINT_TYPE: types.UintType, + value.TypeValue_DOUBLE_TYPE: types.DoubleType, + value.TypeValue_STRING_TYPE: types.StringType, + value.TypeValue_BYTES_TYPE: types.BytesType, + value.TypeValue_TYPE_TYPE: types.TypeType, + value.TypeValue_MAP_TYPE: types.MapType, + value.TypeValue_LIST_TYPE: types.ListType, + } +) + +// Convert res, which must not be error or unknown, to a Value proto. +func RefValueToValue(res ref.Value) (*value.Value, error) { + switch res.Type() { + case types.BoolType: + return &value.Value{ + Kind: &value.Value_BoolValue{res.Value().(bool)}}, nil + case types.BytesType: + return &value.Value{ + Kind: &value.Value_BytesValue{res.Value().([]byte)}}, nil + case types.DoubleType: + return &value.Value{ + Kind: &value.Value_DoubleValue{res.Value().(float64)}}, nil + case types.IntType: + return &value.Value{ + Kind: &value.Value_Int64Value{res.Value().(int64)}}, nil + case types.ListType: + l := res.(traits.Lister) + sz := l.Size().(types.Int) + elts := make([]*value.Value, int64(sz)) + for i := types.Int(0); i < sz; i++ { + v, err := RefValueToValue(l.Get(i)) + if err != nil { + return nil, err + } + elts = append(elts, v) + } + return &value.Value{ + Kind: &value.Value_ListValue{ + &value.ListValue{Values: elts}}}, nil + case types.MapType: + mapper := res.(traits.Mapper) + sz := mapper.Size().(types.Int) + entries := make([]*value.MapValue_Entry, int64(sz)) + for it := mapper.Iterator(); it.HasNext().(types.Bool); { + k := it.Next() + v := mapper.Get(k) + kv, err := RefValueToValue(k) + if err != nil { + return nil, err + } + vv, err := RefValueToValue(v) + if err != nil { + return nil, err + } + entries = append(entries, &value.MapValue_Entry{Key: kv, Value: vv}) + } + return &value.Value{ + Kind: &value.Value_MapValue{ + &value.MapValue{Entries: entries}}}, nil + case types.NullType: + return &value.Value{ + Kind: &value.Value_NullValue{}}, nil + case types.StringType: + return &value.Value{ + Kind: &value.Value_StringValue{res.Value().(string)}}, nil + case types.TypeType: + typeName := res.(ref.Type).TypeName() + var tv *value.TypeValue + if basicType, found := typeNameToBasicType[typeName]; found { + // Names a basic type. + tv = &value.TypeValue{ + DesignatorKind: &value.TypeValue_BasicType_{basicType}} + } else { + // Otherwise names a proto. + tv = &value.TypeValue{ + DesignatorKind: &value.TypeValue_ObjectType{typeName}} + } + return &value.Value{Kind: &value.Value_TypeValue{tv}}, nil + case types.UintType: + return &value.Value{ + Kind: &value.Value_Uint64Value{res.Value().(uint64)}}, nil + default: + // Object type + pb, ok := res.Value().(proto.Message) + if !ok { + return nil, status.New(codes.InvalidArgument, "Expected proto message").Err() + } + any, err := ptypes.MarshalAny(pb) + if err != nil { + return nil, err + } + return &value.Value{ + Kind: &value.Value_ObjectValue{any}}, nil + } + return nil, status.New(codes.InvalidArgument, "unknown ref.Value type").Err() +} + +func ExprValueToRefValue(ev *eval.ExprValue) (ref.Value, error) { + switch ev.Kind.(type) { + case *eval.ExprValue_Value: + return ValueToRefValue(ev.GetValue()) + case *eval.ExprValue_Error: + // An error ExprValue is a repeated set of rpc.Status + // messages, with no convention for the status details. + // To convert this to a types.Err, we need to convert + // these Status messages to a single string, and be + // able to decompose that string on output so we can + // round-trip arbitrary ExprValue messages. + // TODO(jimlarson) make a convention for this. + return types.NewErr("XXX add details later"), nil + case *eval.ExprValue_Unknown: + return types.Unknown(ev.GetUnknown().Exprs), nil + } + return nil, status.New(codes.InvalidArgument, "unknown ExprValue kind").Err() +} + +func ValueToRefValue(v *value.Value) (ref.Value, error) { + switch v.Kind.(type) { + case *value.Value_NullValue: + return types.NullValue, nil + case *value.Value_BoolValue: + return types.Bool(v.GetBoolValue()), nil + case *value.Value_Int64Value: + return types.Int(v.GetInt64Value()), nil + case *value.Value_Uint64Value: + return types.Uint(v.GetUint64Value()), nil + case *value.Value_DoubleValue: + return types.Double(v.GetDoubleValue()), nil + case *value.Value_StringValue: + return types.String(v.GetStringValue()), nil + case *value.Value_BytesValue: + return types.Bytes(v.GetBytesValue()), nil + case *value.Value_ObjectValue: + any := v.GetObjectValue() + var msg ptypes.DynamicAny + if err := ptypes.UnmarshalAny(any, &msg); err != nil { + return nil, err + } + return types.NewObject(msg.Message), nil + case *value.Value_MapValue: + m := v.GetMapValue() + entries := make(map[ref.Value]ref.Value) + for _, entry := range m.Entries { + key, err := ValueToRefValue(entry.Key) + if err != nil { + return nil, err + } + value, err := ValueToRefValue(entry.Value) + if err != nil { + return nil, err + } + entries[key] = value + } + return types.NewDynamicMap(entries), nil + case *value.Value_ListValue: + l := v.GetListValue() + elts := make([]ref.Value, len(l.Values)) + for i, e := range l.Values { + rv, err := ValueToRefValue(e) + if err != nil { + return nil, err + } + elts[i] = rv + } + return types.NewValueList(elts), nil + case *value.Value_TypeValue: + var t *value.TypeValue + t = v.GetTypeValue() + switch t.DesignatorKind.(type) { + case *value.TypeValue_BasicType_: + bt := t.GetBasicType() + tv, ok := basicTypeToTypeValue[bt] + if ok { + return tv, nil + } + return nil, status.New(codes.InvalidArgument, "unknown basic type").Err() + case *value.TypeValue_ObjectType: + o := t.GetObjectType() + return types.NewObjectTypeValue(o), nil + } + return nil, status.New(codes.InvalidArgument, "unknown type designator kind").Err() + } + return nil, status.New(codes.InvalidArgument, "unknown value").Err() +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 00000000..017df74e --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,408 @@ +package server + +import ( + "fmt" + "log" + "os" + "os/exec" + "testing" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/test" + "github.com/google/cel-spec/proto/checked/v1/checked" + "github.com/google/cel-spec/proto/v1/cel_service" + "github.com/google/cel-spec/proto/v1/eval" + "github.com/google/cel-spec/proto/v1/syntax" + "github.com/google/cel-spec/proto/v1/value" + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +type serverTest struct { + cmd *exec.Cmd + conn *grpc.ClientConn + client cel_service.CelServiceClient +} + +var ( + globals = serverTest{} +) + +func TestMain(m *testing.M) { + // Use a helper function to ensure we run shutdown() + // before calling os.Exit() + os.Exit(mainHelper(m)) +} + +func mainHelper(m *testing.M) int { + err := setup() + defer shutdown() + if err != nil { + // testing.M doesn't have a logging method. hmm... + log.Fatal(err) + return 1 + } + return m.Run() +} + +func setup() error { + globals.cmd = exec.Command("cel_server") + + out, err := globals.cmd.StdoutPipe() + if err != nil { + return err + } + globals.cmd.Stderr = os.Stderr // share our error stream + + log.Println("Starting server") + err = globals.cmd.Start() + if err != nil { + return err + } + + log.Println("Getting server's address") + var addr string + _, err = fmt.Fscanf(out, "Listening on %s\n", &addr) + out.Close() + if err != nil { + return err + } + + log.Println("Connecting to ", addr) + conn, err := grpc.Dial(addr, grpc.WithInsecure()) + if err != nil { + return err + } + globals.conn = conn + + log.Println("Creating service client") + globals.client = cel_service.NewCelServiceClient(conn) + return nil +} + +func shutdown() { + if globals.conn != nil { + globals.conn.Close() + globals.conn = nil + } + if globals.cmd != nil { + globals.cmd.Process.Kill() + globals.cmd.Wait() + globals.cmd = nil + } +} + +var ( + parsed = &syntax.ParsedExpr{ + Expr: test.ExprCall(1, operators.Add, + test.ExprLiteral(2, int64(1)), + test.ExprLiteral(3, int64(1))), + SourceInfo: &syntax.SourceInfo{ + Location: "the location", + Positions: map[int64]int32{ + 1: 0, + 2: 0, + 3: 4, + }, + }, + } +) + +func TestParse(t *testing.T) { + req := cel_service.ParseRequest{ + CelSource: "1 + 1", + } + res, err := globals.client.Parse(context.Background(), &req) + if err != nil { + t.Fatal(err) + } + if res == nil { + t.Fatal("Empty result") + } + if res.ParsedExpr == nil { + t.Fatal("Empty parsed expression in result") + } + // Could check against 'parsed' above, + // but the expression ids are arbitrary, + // and explicit comparison logic is about as + // much work as normalization would be. + if res.ParsedExpr.Expr == nil { + t.Fatal("Empty expression in result") + } + switch res.ParsedExpr.Expr.ExprKind.(type) { + case *syntax.Expr_CallExpr: + c := res.ParsedExpr.Expr.GetCallExpr() + if c.Target != nil { + t.Error("Call has target", c) + } + if c.Function != "_+_" { + t.Error("Wrong function", c) + } + if len(c.Args) != 2 { + t.Error("Too many or few args", c) + } + for i, a := range c.Args { + switch a.ExprKind.(type) { + case *syntax.Expr_LiteralExpr: + l := a.GetLiteralExpr() + switch l.LiteralKind.(type) { + case *syntax.Literal_Int64Value: + if l.GetInt64Value() != int64(1) { + t.Errorf("Arg %d wrong value: %v", i, a) + } + default: + t.Errorf("Arg %d not int: %v", i, a) + } + default: + t.Errorf("Arg %d not literal: %v", i, a) + } + } + default: + t.Error("Wrong expression type", res.ParsedExpr.Expr) + } +} + +func TestCheck(t *testing.T) { + // If TestParse() passes, it validates a good chunk + // of the server mechanisms for data conversion, so we + // won't be as fussy here.. + req := cel_service.CheckRequest{ + ParsedExpr: parsed, + } + res, err := globals.client.Check(context.Background(), &req) + if err != nil { + t.Fatal(err) + } + if res == nil { + t.Fatal("Empty result") + } + if res.CheckedExpr == nil { + t.Fatal("No checked expression") + } + tp, present := res.CheckedExpr.TypeMap[int64(1)] + if !present { + t.Fatal("No type for top level expression", res) + } + switch tp.TypeKind.(type) { + case *checked.Type_Primitive: + if tp.GetPrimitive() != checked.Type_INT64 { + t.Error("Bad top-level type", tp) + } + default: + t.Error("Bad top-level type", tp) + } +} + +func TestEval(t *testing.T) { + req := cel_service.EvalRequest{ + ExprKind: &cel_service.EvalRequest_ParsedExpr{parsed}, + } + res, err := globals.client.Eval(context.Background(), &req) + if err != nil { + t.Fatal(err) + } + if res == nil || res.Result == nil { + t.Fatal("Nil result") + } + switch res.Result.Kind.(type) { + case *eval.ExprValue_Value: + v := res.Result.GetValue() + switch v.Kind.(type) { + case *value.Value_Int64Value: + if v.GetInt64Value() != int64(2) { + t.Error("Wrong result for 1 + 1", v) + } + default: + t.Error("Wrong result value type", v) + } + default: + t.Fatal("Result not a value", res.Result) + } +} + +func TestFullUp(t *testing.T) { + preq := cel_service.ParseRequest{ + CelSource: "x + y", + } + pres, err := globals.client.Parse(context.Background(), &preq) + if err != nil { + t.Fatal(err) + } + parsedExpr := pres.ParsedExpr + if parsedExpr == nil { + t.Fatal("Empty parsed expression") + } + + creq := cel_service.CheckRequest{ + ParsedExpr: parsedExpr, + TypeEnv: []*checked.Decl{ + decls.NewIdent("x", decls.Int, nil), + decls.NewIdent("y", decls.Int, nil), + }, + } + cres, err := globals.client.Check(context.Background(), &creq) + if err != nil { + t.Fatal(err) + } + if cres == nil { + t.Fatal("Empty check result") + } + checkedExpr := cres.CheckedExpr + if checkedExpr == nil { + t.Fatal("No checked expression") + } + tp, present := checkedExpr.TypeMap[int64(1)] + if !present { + t.Fatal("No type for top level expression", cres) + } + switch tp.TypeKind.(type) { + case *checked.Type_Primitive: + if tp.GetPrimitive() != checked.Type_INT64 { + t.Error("Bad top-level type", tp) + } + default: + t.Error("Bad top-level type", tp) + } + + ereq := cel_service.EvalRequest{ + ExprKind: &cel_service.EvalRequest_CheckedExpr{checkedExpr}, + Bindings: map[string]*eval.ExprValue{ + "x": exprValueInt64(1), + "y": exprValueInt64(2), + }, + } + eres, err := globals.client.Eval(context.Background(), &ereq) + if err != nil { + t.Fatal(err) + } + if eres == nil || eres.Result == nil { + t.Fatal("Nil result") + } + switch eres.Result.Kind.(type) { + case *eval.ExprValue_Value: + v := eres.Result.GetValue() + switch v.Kind.(type) { + case *value.Value_Int64Value: + if v.GetInt64Value() != int64(3) { + t.Error("Wrong result for 1 + 2", v) + } + default: + t.Error("Wrong result value type", v) + } + default: + t.Fatal("Result not a value", eres.Result) + } +} + +func exprValueInt64(x int64) *eval.ExprValue { + return &eval.ExprValue{ + Kind: &eval.ExprValue_Value{ + &value.Value{ + Kind: &value.Value_Int64Value{x}, + }, + }, + } +} + +// expectEvalTrue parses, checks, and evaluates the CEL expression in source +// and checks that the result is the boolean value 'true'. +func expectEvalTrue(t *testing.T, source string) { + // Parse + preq := cel_service.ParseRequest{ + CelSource: source, + } + pres, err := globals.client.Parse(context.Background(), &preq) + if err != nil { + t.Fatal(err) + } + if pres == nil { + t.Fatal("Empty parse result") + } + parsedExpr := pres.ParsedExpr + if parsedExpr == nil { + t.Fatal("Empty parsed expression") + } + if parsedExpr.Expr == nil { + t.Fatal("Empty root expression") + } + rootId := parsedExpr.Expr.Id + + // Check + creq := cel_service.CheckRequest{ + ParsedExpr: parsedExpr, + } + cres, err := globals.client.Check(context.Background(), &creq) + if err != nil { + t.Fatal(err) + } + if cres == nil { + t.Fatal("Empty check result") + } + checkedExpr := cres.CheckedExpr + if checkedExpr == nil { + t.Fatal("No checked expression") + } + topType, present := checkedExpr.TypeMap[rootId] + if !present { + t.Fatal("No type for top level expression", cres) + } + switch topType.TypeKind.(type) { + case *checked.Type_Primitive: + if topType.GetPrimitive() != checked.Type_BOOL { + t.Error("Bad top-level type", topType) + } + default: + t.Error("Bad top-level type", topType) + } + + // Eval + ereq := cel_service.EvalRequest{ + ExprKind: &cel_service.EvalRequest_CheckedExpr{checkedExpr}, + } + eres, err := globals.client.Eval(context.Background(), &ereq) + if err != nil { + t.Fatal(err) + } + if eres == nil || eres.Result == nil { + t.Fatal("Nil result") + } + switch eres.Result.Kind.(type) { + case *eval.ExprValue_Value: + v := eres.Result.GetValue() + switch v.Kind.(type) { + case *value.Value_BoolValue: + if !v.GetBoolValue() { + t.Error("Wrong result", v) + } + default: + t.Error("Wrong result value type", v) + } + default: + t.Fatal("Result not a value", eres.Result) + } +} + +func TestCondTrue(t *testing.T) { + expectEvalTrue(t, "(true ? 'a' : 'b') == 'a'") +} + +func TestCondFalse(t *testing.T) { + expectEvalTrue(t, "(false ? 'a' : 'b') == 'b'") +} + +func TestMapOrderInsignificant(t *testing.T) { + expectEvalTrue(t, "{1: 'a', 2: 'b'} == {2: 'b', 1: 'a'}") +} + +func FailsTestOneMetaType(t *testing.T) { + expectEvalTrue(t, "type(type(1)) == type(type('foo'))") +} + +func FailsTestTypeType(t *testing.T) { + expectEvalTrue(t, "type(type) == type") +} + +func FailsTestNullTypeName(t *testing.T) { + expectEvalTrue(t, "type(null) == null_type") +} diff --git a/test/BUILD.bazel b/test/BUILD.bazel index f81fef28..48f54700 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -7,6 +7,7 @@ package( "//common:__subpackages__", "//interpreter:__subpackages__", "//parser:__subpackages__", + "//server:__subpackages__", ], ) @@ -20,8 +21,8 @@ go_library( importpath = "github.com/google/cel-go/test", deps = [ "//common/operators:go_default_library", - "//io:syntax_proto", "@com_github_golang_protobuf//proto:go_default_library", + "@com_google_cel_spec//proto/v1:syntax_go_proto", "@io_bazel_rules_go//proto/wkt:any_go_proto", "@io_bazel_rules_go//proto/wkt:duration_go_proto", "@io_bazel_rules_go//proto/wkt:struct_go_proto",