From 416f0793fa85b3bf5b1081e700a7b839f9f0c408 Mon Sep 17 00:00:00 2001 From: Adphi Date: Thu, 23 Sep 2021 20:39:29 +0200 Subject: [PATCH] add casttype support for scalar types --- go.mod | 1 + go.sum | 3 +- patch/casttype.go | 155 +++++++++++ patch/casttype_test.go | 257 +++++++++++++++++ patch/go.proto | 6 + patch/gopb/go.pb.go | 111 ++++---- patch/patcher.go | 84 +++++- tests/message/message.extensions.go | 17 ++ tests/message/message_casttypes.pb.go | 387 ++++++++++++++++++++++++++ tests/message/message_casttypes.proto | 26 ++ tests/message/message_renames.proto | 3 +- tests/message/message_test.go | 35 ++- tests/plugin/validate.pb.go | 70 ++--- tests/plugin/validate.pb.validate.go | 16 +- tests/plugin/validate.proto | 2 +- tests/plugin/validate_test.go | 7 +- 16 files changed, 1085 insertions(+), 95 deletions(-) create mode 100644 patch/casttype.go create mode 100644 patch/casttype_test.go create mode 100644 tests/message/message.extensions.go create mode 100644 tests/message/message_casttypes.pb.go create mode 100644 tests/message/message_casttypes.proto diff --git a/go.mod b/go.mod index 9a73c85..1a114cc 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/fatih/structtag v1.2.0 github.com/iancoleman/strcase v0.1.2 // indirect github.com/lyft/protoc-gen-star v0.5.2 // indirect + github.com/stretchr/testify v1.7.0 golang.org/x/tools v0.1.6 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index 0e75cfa..c5c7096 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,9 @@ github.com/spf13/afero v1.3.4 h1:8q6vk3hthlpb2SouZcnBVKboxWQWMDNF38bwholZrJc= github.com/spf13/afero v1.3.4/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/patch/casttype.go b/patch/casttype.go new file mode 100644 index 0000000..02db2e2 --- /dev/null +++ b/patch/casttype.go @@ -0,0 +1,155 @@ +package patch + +import ( + "go/ast" + "go/types" + "log" + "strings" + + "golang.org/x/tools/go/ast/astutil" +) + +func (p *Patcher) patchTypeDef(id *ast.Ident, obj types.Object) { + castType, ok := p.fieldCastType[obj] + if !ok { + return + } + + parent := p.findParentNode(id) + if pkg, name := packageAndName(castType); pkg != "" { + f := p.fileOf(id) + pkgImport := strings.Replace(strings.Replace(pkg, "/", "_", -1), ".", "_", -1) + astutil.AddNamedImport(p.fset, f, pkgImport, pkg) + castType = pkgImport + "." + name + } + // Cast Field definition + if id.Obj != nil && id.Obj.Decl != nil { + v, ok := id.Obj.Decl.(*ast.Field) + if !ok { + log.Printf("Warning: casttype declared for non-field object: %v `%s`", obj, castType) + return + } + t, ok := v.Type.(*ast.Ident) + if ok { + t.Name = castType + return + } + } + switch obj.Type().(type) { + // Cast Getter signature + case *types.Signature: + n, ok := parent.(*ast.FuncDecl) + if !ok { + log.Printf("Warning: unexpected type for getter: %v `%T`", obj, parent) + break + } + if l := len(n.Type.Results.List); l != 1 { + log.Printf("Warning: unexpected return count for getter: %v `%d`", obj, l) + return + } + if ident, ok := n.Type.Results.List[0].Type.(*ast.Ident); ok { + ident.Name = castType + return + } + } +} + +func (p *Patcher) patchTypeUsage(id *ast.Ident, obj types.Object) { + desiredType, ok := p.fieldCastType[obj] + if !ok { + return + } + var originalType string + switch t := obj.Type().(type) { + case *types.Basic: + originalType = t.Name() + case *types.Signature: + if t.Results().Len() != 1 { + return + } + originalType = t.Results().At(0).Type().String() + } + usageNode := p.findParentNode(id) + pkgPath, pkgName := packageAndName(desiredType) + pkgImport := strings.Replace(strings.Replace(pkgPath, "/", "_", -1), ".", "_", -1) + if pkgPath != "" { + desiredType = pkgImport + "." + pkgName + } + cast := func(as string, expr ast.Expr) ast.Expr { + if pkgPath != "" && as == desiredType { + f := p.fileOf(id) + astutil.AddNamedImport(p.fset, f, pkgImport, pkgPath) + } + return &ast.CallExpr{ + Fun: &ast.Ident{ + Name: as, + }, + Args: []ast.Expr{expr}, + } + } + parentNode := p.findParentNode(usageNode) + + switch usage := usageNode.(type) { + case *ast.SelectorExpr: + switch parentExpr := parentNode.(type) { + case *ast.AssignStmt: + if len(parentExpr.Lhs) != len(parentExpr.Rhs) { + return + } + for i := range parentExpr.Lhs { + if parentExpr.Lhs[i] == usage { + parentExpr.Rhs[i] = cast(desiredType, parentExpr.Rhs[i]) + return + } + } + for i := range parentExpr.Rhs { + if parentExpr.Rhs[i] == usage { + parentExpr.Rhs[i] = cast(originalType, parentExpr.Rhs[i]) + return + } + } + case *ast.CallExpr: + parent := p.findParentNode(parentExpr) + assign, isAssign := parent.(*ast.AssignStmt) + if parentExpr.Fun == usage && isAssign { + for i := range assign.Rhs { + if assign.Rhs[i] == parentExpr { + assign.Rhs[i] = cast(originalType, assign.Rhs[i]) + return + } + } + } + call, isCall := parent.(*ast.CallExpr) + if isCall { + for i := range call.Args { + if call.Args[i] == parentExpr { + call.Args[i] = cast(originalType, call.Args[i]) + return + } + } + } + for i, v := range parentExpr.Args { + if v == usage { + parentExpr.Args[i] = cast(originalType, usage) + return + } + } + case *ast.BinaryExpr: + if parentExpr.X == usage { + parentExpr.X = cast(originalType, parentExpr.X) + } + if parentExpr.Y == usage { + parentExpr.Y = cast(originalType, parentExpr.Y) + } + } + case *ast.KeyValueExpr: + if usage.Key == id { + usage.Value = cast(desiredType, usage.Value) + return + } + if usage.Value == id { + usage.Value = cast(originalType, usage.Value) + return + } + } +} diff --git a/patch/casttype_test.go b/patch/casttype_test.go new file mode 100644 index 0000000..feb5e09 --- /dev/null +++ b/patch/casttype_test.go @@ -0,0 +1,257 @@ +package patch + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/compiler/protogen" +) + +const ( + srcDef = `package foo + +import ( + "fmt" +) + +type String string + +type Message struct { + Content string +} + +func (m *Message) GetContent() string { + if m != nil { + return m.Content + } + return "" +} + +func print(s string) { + fmt.Println(s) +} +` + wantDef = `package foo + +import ( + "fmt" +) + +type String string + +type Message struct { + Content String +} + +func (m *Message) GetContent() String { + if m != nil { + return m.Content + } + return "" +} + +func print(s string) { + fmt.Println(s) +} +` +) + +const ( + fileName = "foo.go" + packageName = "foo" + fieldName = "Content" + msgName = "Message" + casttype = "String" +) + +func prepareCastType(src string) (*Patcher, *ast.File, error) { + p := &Patcher{ + fset: token.NewFileSet(), + info: &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + }, + filesByName: make(map[string]*ast.File), + packagesByPath: make(map[string]*Package), + packagesByName: make(map[string]*Package), + renames: make(map[protogen.GoIdent]string), + typeRenames: make(map[protogen.GoIdent]string), + valueRenames: make(map[protogen.GoIdent]string), + fieldRenames: make(map[protogen.GoIdent]string), + methodRenames: make(map[protogen.GoIdent]string), + objectRenames: make(map[types.Object]string), + tags: make(map[protogen.GoIdent]string), + fieldTags: make(map[types.Object]string), + embeds: make(map[protogen.GoIdent]string), + fieldEmbeds: make(map[types.Object]string), + castType: make(map[protogen.GoIdent]string), + fieldCastType: make(map[types.Object]string), + } + file, err := parser.ParseFile(p.fset, fileName, src, parser.ParseComments) + if err != nil { + return nil, nil, err + } + pkg := NewPackage(packageName, packageName) + if err := pkg.AddFile(fileName, file); err != nil { + return nil, nil, err + } + if err := pkg.Check(basicImporter{p}, p.fset, p.info); err != nil { + return nil, nil, err + } + p.filesByName[fileName] = file + p.packagesByPath[packageName] = pkg + p.packagesByName[packageName] = pkg + p.packages = append(p.packages, pkg) + p.CastType(protogen.GoIdent{GoName: msgName + "." + fieldName, GoImportPath: packageName}, casttype) + p.CastType(protogen.GoIdent{GoName: msgName + "." + "Get" + fieldName, GoImportPath: packageName}, casttype) + // Map cast types + for id, typ := range p.castType { + obj, _ := p.find(id) + if obj == nil { + continue + } + p.fieldCastType[obj] = typ + } + return p, file, nil +} + +func TestScalarCastType(t *testing.T) { + tests := []struct{ + name string + src string + want string + }{ + { + name: "cast definition", + src: srcDef, + want: wantDef, + }, + { + name: "cast field initialization", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{Content: s} +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{Content: String(s)} +} +`, + }, + { + name: "cast field assignation", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{} + msg.Content = s +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{} + msg.Content = String(s) +} +`, + }, + { + name: "cast field assignation usage", + src: srcDef+` +func useContent() { + var s string + msg := &Message{Content: "ok"} + s = msg.Content +} +`, + want: wantDef+` +func useContent() { + var s string + msg := &Message{Content: String("ok")} + s = string(msg.Content) +} +`, + }, + { + name: "cast field used as argument", + src: srcDef+` +func useContent() { + msg := &Message{Content: "ok"} + var s string + s = msg.Content +} +`, + want: wantDef+` +func useContent() { + msg := &Message{Content: String("ok")} + var s string + s = string(msg.Content) +} +`, + }, + { + name: "cast full code", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{Content: s} + print(msg.Content) + print(msg.GetContent()) + msg.Content = s + useContentGetter(msg) +} + +func useContentGetter(msg *Message) { + var s string + s = msg.GetContent() + s = msg.Content + s = msg.Content + "..." + print(s) +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{Content: String(s)} + print(string(msg.Content)) + print(string(msg.GetContent())) + msg.Content = String(s) + useContentGetter(msg) +} + +func useContentGetter(msg *Message) { + var s string + s = string(msg.GetContent()) + s = string(msg.Content) + s = string(msg.Content) + "..." + print(s) +} +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, file, err := prepareCastType(tt.src) + if err != nil { + t.Errorf("failed to initilialize test") + return + } + + if err := p.patchGoFiles(); err != nil { + t.Fatal(err) + } + got := p.nodeToString(file) + assert.Equal(t, tt.want, got) + }) + } + +} diff --git a/patch/go.proto b/patch/go.proto index 31c5675..41d13e7 100644 --- a/patch/go.proto +++ b/patch/go.proto @@ -21,6 +21,12 @@ message Options { // See https://golang.org/ref/spec#Struct_types. optional bool embed = 2; + // The casstype option changes the generated field type. + // All generated code assumes that this type is castable to the protocol buffer field type, + // so it does not work for messages types. + // Not supported for repeated fields. + optional string casttype = 3; + // The getter option renames the generated getter method (default: Get) // so a custom getter can be implemented in its place. optional string getter = 10; // TODO: implement this diff --git a/patch/gopb/go.pb.go b/patch/gopb/go.pb.go index 45dc636..1c80d32 100644 --- a/patch/gopb/go.pb.go +++ b/patch/gopb/go.pb.go @@ -38,6 +38,11 @@ type Options struct { // Only message types can be embedded. Oneof fields cannot be embedded. // See https://golang.org/ref/spec#Struct_types. Embed *bool `protobuf:"varint,2,opt,name=embed" json:"embed,omitempty"` + // The casstype option changes the generated field type. + // All generated code assumes that this type is castable to the protocol buffer field type, + // so it does not work for messages types. + // Not supported for repeated fields. + Casttype *string `protobuf:"bytes,3,opt,name=casttype" json:"casttype,omitempty"` // The getter option renames the generated getter method (default: Get) // so a custom getter can be implemented in its place. Getter *string `protobuf:"bytes,10,opt,name=getter" json:"getter,omitempty"` // TODO: implement this @@ -99,6 +104,13 @@ func (x *Options) GetEmbed() bool { return false } +func (x *Options) GetCasttype() string { + if x != nil && x.Casttype != nil { + return *x.Casttype + } + return "" +} + func (x *Options) GetGetter() string { if x != nil && x.Getter != nil { return *x.Getter @@ -328,58 +340,59 @@ var file_patch_go_proto_rawDesc = []byte{ 0x0a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x67, 0x6f, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa0, 0x01, 0x0a, 0x07, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xbc, 0x01, 0x0a, 0x07, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, - 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x67, 0x65, - 0x74, 0x74, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x14, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x65, 0x72, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x74, 0x72, - 0x69, 0x6e, 0x67, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0xc3, 0x01, 0x0a, 0x0b, 0x4c, 0x69, - 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, - 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x1e, 0x0a, - 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x20, 0x0a, - 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x18, 0x0a, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x3a, - 0x47, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x3a, 0x41, 0x0a, 0x05, 0x66, 0x69, 0x65, 0x6c, - 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x3a, 0x41, 0x0a, 0x05, 0x6f, - 0x6e, 0x65, 0x6f, 0x66, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, - 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x3a, 0x3e, - 0x0a, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x4f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, - 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x3a, 0x45, - 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, - 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x42, 0x0a, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x12, 0x1c, 0x2e, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, + 0x63, 0x61, 0x73, 0x74, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x63, 0x61, 0x73, 0x74, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x67, 0x65, 0x74, 0x74, + 0x65, 0x72, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, + 0x12, 0x12, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x74, 0x61, 0x67, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, + 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, + 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, + 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0xc3, 0x01, 0x0a, 0x0b, 0x4c, 0x69, 0x6e, 0x74, 0x4f, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x65, + 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x6e, 0x75, 0x6d, + 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x78, 0x74, + 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x65, + 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x6e, 0x69, + 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, + 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x3a, 0x47, 0x0a, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, + 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x3a, 0x41, 0x0a, 0x05, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x46, 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x67, 0x6f, 0x2e, 0x4c, 0x69, 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x42, 0x27, 0x5a, 0x25, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, - 0x70, 0x62, + 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x52, 0x05, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x3a, 0x41, 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, + 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x3a, 0x3e, 0x0a, 0x04, 0x65, 0x6e, + 0x75, 0x6d, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x3a, 0x45, 0x0a, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, + 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x42, 0x0a, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, + 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x6c, 0x65, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, + 0x2e, 0x67, 0x6f, 0x2e, 0x4c, 0x69, 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, + 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x42, 0x27, 0x5a, 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, + 0x74, 0x63, 0x68, 0x2f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x70, 0x62, } var ( diff --git a/patch/patcher.go b/patch/patcher.go index 2c70f95..0b30ce7 100644 --- a/patch/patcher.go +++ b/patch/patcher.go @@ -6,9 +6,11 @@ import ( "go/ast" "go/format" "go/parser" + "go/printer" "go/token" "go/types" "log" + "os" "path/filepath" "regexp" "strings" @@ -48,7 +50,9 @@ type Patcher struct { tags map[protogen.GoIdent]string fieldTags map[types.Object]string embeds map[protogen.GoIdent]string + castType map[protogen.GoIdent]string fieldEmbeds map[types.Object]string + fieldCastType map[types.Object]string } // NewPatcher returns an initialized Patcher for gen. @@ -67,6 +71,8 @@ func NewPatcher(gen *protogen.Plugin) (*Patcher, error) { fieldTags: make(map[types.Object]string), embeds: make(map[protogen.GoIdent]string), fieldEmbeds: make(map[types.Object]string), + castType: make(map[protogen.GoIdent]string), + fieldCastType: make(map[types.Object]string), } return p, p.scan() } @@ -310,6 +316,22 @@ func (p *Patcher) scanField(f *protogen.Field) { p.RenameMethod(ident.WithChild(m.GoIdent, "Get"+f.GoName), "Get"+newName) // Getter } + // check casttype + if typ := opts.GetCasttype(); typ != "" { + switch { + case f.Desc.IsList(): + log.Printf("Warning: casttype declared for repeated field: %s", f.Desc.Name()) + case f.Message != nil: + log.Printf("Warning: casttype declared for message field: %s", f.Desc.Name()) + case f.Oneof != nil: + p.CastType(ident.WithChild(f.GoIdent, f.GoName), typ) + p.CastType(ident.WithChild(m.GoIdent, "Get"+f.GoName), typ) + default: + p.CastType(ident.WithChild(m.GoIdent, f.GoName), typ) + p.CastType(ident.WithChild(m.GoIdent, "Get"+f.GoName), typ) + } + } + // Add or replace any struct tags? tags := opts.GetTags() if tags != "" { @@ -397,6 +419,14 @@ func (p *Patcher) nameFor(id protogen.GoIdent) string { return ident.LeafName(id) } +// CastType casts the Go struct field as the desired castType/ +// The castType value must be a named type, e.g.: "type String string" +// It can also be a fully qualified type name, e.g.: "go.repo.com/mymodule/types.String" +func (p *Patcher) CastType(id protogen.GoIdent, castType string) { + p.castType[id] = castType + log.Printf("Cast type:\t%s.%s → %s", id.GoImportPath, id.GoName, castType) +} + // Tag adds the specified struct tags to the field specified by selector, // in the form of "Message.Field". The tags argument should omit outer backticks (`). // The value of id.GoName should be the original generated identifier name, not a renamed identifier. @@ -540,6 +570,15 @@ func (p *Patcher) checkGoFiles() error { } } + // Map cast types + for id, typ := range p.castType { + obj, _ := p.find(id) + if obj == nil { + continue + } + p.fieldCastType[obj] = typ + } + // Map struct tags. for id, tags := range p.tags { obj, _ := p.find(id) @@ -664,6 +703,7 @@ func (p *Patcher) serializeGoFiles(res *pluginpb.CodeGeneratorResponse) error { var b strings.Builder err := format.Node(&b, p.fset, f) if err != nil { + printer.Fprint(os.Stderr, p.fset, f) return err } @@ -676,6 +716,7 @@ func (p *Patcher) serializeGoFiles(res *pluginpb.CodeGeneratorResponse) error { func (p *Patcher) patchGoFiles() error { log.Printf("\nDefs") for id, obj := range p.info.Defs { + p.patchTypeDef(id, obj) p.patchIdent(id, obj, true) p.patchTags(id, obj) // if id.IsExported() { @@ -686,6 +727,7 @@ func (p *Patcher) patchGoFiles() error { log.Printf("\nUses\n") for id, obj := range p.info.Uses { + p.patchTypeUsage(id, obj) p.patchIdent(id, obj, false) } @@ -715,6 +757,26 @@ func (p *Patcher) patchIdent(id *ast.Ident, obj types.Object, isDecl bool) { } } +func (p *Patcher) nodeToString(n ast.Node) string { + b := &bytes.Buffer{} + if err := printer.Fprint(b, p.fset, n); err != nil { + log.Fatal(err) + } + return b.String() +} + +func (p *Patcher) findParentNode(id ast.Node) ast.Node { + var node ast.Node + astutil.Apply(p.fileOf(id), func(cursor *astutil.Cursor) bool { + if cursor.Node() == id { + node = cursor.Parent() + return false + } + return true + }, nil) + return node +} + func (p *Patcher) patchTags(id *ast.Ident, obj types.Object) { fieldTags := p.fieldTags[obj] if fieldTags == "" || id.Obj == nil { @@ -768,11 +830,7 @@ func (p *Patcher) patchComments(id *ast.Ident, repl string) { // Borrowed from https://github.com/golang/tools/blob/HEAD/refactor/rename/rename.go#L543 func (p *Patcher) findCommentGroups(id *ast.Ident) (doc *ast.CommentGroup, comment *ast.CommentGroup) { - tf := p.fset.File(id.Pos()) - if tf == nil { - return - } - f := p.filesByName[tf.Name()] + f := p.fileOf(id) if f == nil { return } @@ -802,6 +860,14 @@ func (p *Patcher) findCommentGroups(id *ast.Ident) (doc *ast.CommentGroup, comme return } +func (p *Patcher) fileOf(node ast.Node) *ast.File { + tf := p.fset.File(node.Pos()) + if tf == nil { + return nil + } + return p.filesByName[tf.Name()] +} + func patchCommentGroup(c *ast.CommentGroup, x *regexp.Regexp, repl string) { if c == nil { return @@ -834,3 +900,11 @@ func typeString(obj types.Object) string { } return obj.Type().String() } + +func packageAndName(fqn string) (pkg string, name string) { + parts := strings.Split(fqn, ".") + if len(parts) == 1 { + return "", fqn + } + return strings.Join(parts[:len(parts)-1], "."), parts[len(parts)-1] +} diff --git a/tests/message/message.extensions.go b/tests/message/message.extensions.go new file mode 100644 index 0000000..4718488 --- /dev/null +++ b/tests/message/message.extensions.go @@ -0,0 +1,17 @@ +package message + +type Name string + +type Int32 int32 + +type Int64 int64 + +type String string + +type Float float32 + +type Double float64 + +type Uint32 uint32 + +type Uint64 uint64 diff --git a/tests/message/message_casttypes.pb.go b/tests/message/message_casttypes.pb.go new file mode 100644 index 0000000..f5e986a --- /dev/null +++ b/tests/message/message_casttypes.pb.go @@ -0,0 +1,387 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.17.3 +// source: tests/message/message_casttypes.proto + +package message + +import ( + _ "github.com/alta/protopatch/patch/gopb" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MessageWithCustomTypes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StringField String `protobuf:"bytes,1,opt,name=string_field,json=stringField,proto3" json:"string_field,omitempty"` + Int32Field Int32 `protobuf:"varint,2,opt,name=int32_field,json=int32Field,proto3" json:"int32_field,omitempty"` + Int64Field Int64 `protobuf:"varint,3,opt,name=int64_field,json=int64Field,proto3" json:"int64_field,omitempty"` + FloatField Float `protobuf:"fixed32,4,opt,name=float_field,json=floatField,proto3" json:"float_field,omitempty"` + DoubleField Double `protobuf:"fixed64,5,opt,name=double_field,json=doubleField,proto3" json:"double_field,omitempty"` + Uint32Field Uint32 `protobuf:"varint,6,opt,name=uint32_field,json=uint32Field,proto3" json:"uint32_field,omitempty"` + Uint64Field Uint64 `protobuf:"varint,7,opt,name=uint64_field,json=uint64Field,proto3" json:"uint64_field,omitempty"` +} + +func (x *MessageWithCustomTypes) Reset() { + *x = MessageWithCustomTypes{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_casttypes_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithCustomTypes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithCustomTypes) ProtoMessage() {} + +func (x *MessageWithCustomTypes) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_casttypes_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithCustomTypes.ProtoReflect.Descriptor instead. +func (*MessageWithCustomTypes) Descriptor() ([]byte, []int) { + return file_tests_message_message_casttypes_proto_rawDescGZIP(), []int{0} +} + +func (x *MessageWithCustomTypes) GetStringField() String { + if x != nil { + return x.StringField + } + return "" +} + +func (x *MessageWithCustomTypes) GetInt32Field() Int32 { + if x != nil { + return x.Int32Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetInt64Field() Int64 { + if x != nil { + return x.Int64Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetFloatField() Float { + if x != nil { + return x.FloatField + } + return 0 +} + +func (x *MessageWithCustomTypes) GetDoubleField() Double { + if x != nil { + return x.DoubleField + } + return 0 +} + +func (x *MessageWithCustomTypes) GetUint32Field() Uint32 { + if x != nil { + return x.Uint32Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetUint64Field() Uint64 { + if x != nil { + return x.Uint64Field + } + return 0 +} + +type MessageWithOneOfCustomType struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to OneOf: + // *MessageWithOneOfCustomType_StringField + // *MessageWithOneOfCustomType_Int64Field + OneOf isMessageWithOneOfCustomType_OneOf `protobuf_oneof:"one_of"` +} + +func (x *MessageWithOneOfCustomType) Reset() { + *x = MessageWithOneOfCustomType{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_casttypes_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithOneOfCustomType) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithOneOfCustomType) ProtoMessage() {} + +func (x *MessageWithOneOfCustomType) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_casttypes_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithOneOfCustomType.ProtoReflect.Descriptor instead. +func (*MessageWithOneOfCustomType) Descriptor() ([]byte, []int) { + return file_tests_message_message_casttypes_proto_rawDescGZIP(), []int{1} +} + +func (m *MessageWithOneOfCustomType) GetOneOf() isMessageWithOneOfCustomType_OneOf { + if m != nil { + return m.OneOf + } + return nil +} + +func (x *MessageWithOneOfCustomType) GetStringField() String { + if x, ok := x.GetOneOf().(*MessageWithOneOfCustomType_StringField); ok { + return x.StringField + } + return "" +} + +func (x *MessageWithOneOfCustomType) GetInt64Field() Int64 { + if x, ok := x.GetOneOf().(*MessageWithOneOfCustomType_Int64Field); ok { + return x.Int64Field + } + return 0 +} + +type isMessageWithOneOfCustomType_OneOf interface { + isMessageWithOneOfCustomType_OneOf() +} + +type MessageWithOneOfCustomType_StringField struct { + StringField String `protobuf:"bytes,1,opt,name=string_field,json=stringField,proto3,oneof"` +} + +type MessageWithOneOfCustomType_Int64Field struct { + Int64Field Int64 `protobuf:"varint,3,opt,name=int64_field,json=int64Field,proto3,oneof"` +} + +func (*MessageWithOneOfCustomType_StringField) isMessageWithOneOfCustomType_OneOf() {} + +func (*MessageWithOneOfCustomType_Int64Field) isMessageWithOneOfCustomType_OneOf() {} + +type MessageWithRepeatedCustomTypes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StringField []string `protobuf:"bytes,1,rep,name=string_field,json=stringField,proto3" json:"string_field,omitempty"` +} + +func (x *MessageWithRepeatedCustomTypes) Reset() { + *x = MessageWithRepeatedCustomTypes{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_casttypes_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithRepeatedCustomTypes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithRepeatedCustomTypes) ProtoMessage() {} + +func (x *MessageWithRepeatedCustomTypes) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_casttypes_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithRepeatedCustomTypes.ProtoReflect.Descriptor instead. +func (*MessageWithRepeatedCustomTypes) Descriptor() ([]byte, []int) { + return file_tests_message_message_casttypes_proto_rawDescGZIP(), []int{2} +} + +func (x *MessageWithRepeatedCustomTypes) GetStringField() []string { + if x != nil { + return x.StringField + } + return nil +} + +var File_tests_message_message_casttypes_proto protoreflect.FileDescriptor + +var file_tests_message_message_casttypes_proto_rawDesc = []byte{ + 0x0a, 0x25, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2f, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x63, 0x61, 0x73, 0x74, 0x74, 0x79, 0x70, 0x65, + 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe6, 0x02, 0x0a, 0x16, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x57, 0x69, 0x74, 0x68, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, 0x70, 0x65, + 0x73, 0x12, 0x2f, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x53, + 0x74, 0x72, 0x69, 0x6e, 0x67, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, 0x69, 0x65, + 0x6c, 0x64, 0x12, 0x2c, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x5f, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x49, + 0x6e, 0x74, 0x33, 0x32, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x12, 0x2c, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x03, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x49, 0x6e, 0x74, + 0x36, 0x34, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2c, + 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x02, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x46, 0x6c, 0x6f, 0x61, 0x74, + 0x52, 0x0a, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2f, 0x0a, 0x0c, + 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x05, 0x20, 0x01, + 0x28, 0x01, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x44, 0x6f, 0x75, 0x62, 0x6c, 0x65, + 0x52, 0x0b, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2f, 0x0a, + 0x0c, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x06, 0x20, + 0x01, 0x28, 0x0d, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x55, 0x69, 0x6e, 0x74, 0x33, + 0x32, 0x52, 0x0b, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2f, + 0x0a, 0x0c, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x04, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x55, 0x69, 0x6e, 0x74, + 0x36, 0x34, 0x52, 0x0b, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x22, + 0x89, 0x01, 0x0a, 0x1a, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, 0x68, 0x4f, + 0x6e, 0x65, 0x4f, 0x66, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, 0x70, 0x65, 0x12, 0x31, + 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x53, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x48, 0x00, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, 0x69, 0x65, 0x6c, + 0x64, 0x12, 0x2e, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x49, 0x6e, + 0x74, 0x36, 0x34, 0x48, 0x00, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, 0x65, 0x6c, + 0x64, 0x42, 0x08, 0x0a, 0x06, 0x6f, 0x6e, 0x65, 0x5f, 0x6f, 0x66, 0x22, 0x51, 0x0a, 0x1e, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, 0x68, 0x52, 0x65, 0x70, 0x65, 0x61, 0x74, + 0x65, 0x64, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, 0x70, 0x65, 0x73, 0x12, 0x2f, 0x0a, + 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x09, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x53, 0x74, 0x72, 0x69, 0x6e, + 0x67, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x42, 0x2a, + 0x5a, 0x28, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, + 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, + 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, +} + +var ( + file_tests_message_message_casttypes_proto_rawDescOnce sync.Once + file_tests_message_message_casttypes_proto_rawDescData = file_tests_message_message_casttypes_proto_rawDesc +) + +func file_tests_message_message_casttypes_proto_rawDescGZIP() []byte { + file_tests_message_message_casttypes_proto_rawDescOnce.Do(func() { + file_tests_message_message_casttypes_proto_rawDescData = protoimpl.X.CompressGZIP(file_tests_message_message_casttypes_proto_rawDescData) + }) + return file_tests_message_message_casttypes_proto_rawDescData +} + +var file_tests_message_message_casttypes_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_tests_message_message_casttypes_proto_goTypes = []interface{}{ + (*MessageWithCustomTypes)(nil), // 0: tests.message.MessageWithCustomTypes + (*MessageWithOneOfCustomType)(nil), // 1: tests.message.MessageWithOneOfCustomType + (*MessageWithRepeatedCustomTypes)(nil), // 2: tests.message.MessageWithRepeatedCustomTypes +} +var file_tests_message_message_casttypes_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_tests_message_message_casttypes_proto_init() } +func file_tests_message_message_casttypes_proto_init() { + if File_tests_message_message_casttypes_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_tests_message_message_casttypes_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithCustomTypes); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_tests_message_message_casttypes_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithOneOfCustomType); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_tests_message_message_casttypes_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithRepeatedCustomTypes); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_tests_message_message_casttypes_proto_msgTypes[1].OneofWrappers = []interface{}{ + (*MessageWithOneOfCustomType_StringField)(nil), + (*MessageWithOneOfCustomType_Int64Field)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_tests_message_message_casttypes_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_tests_message_message_casttypes_proto_goTypes, + DependencyIndexes: file_tests_message_message_casttypes_proto_depIdxs, + MessageInfos: file_tests_message_message_casttypes_proto_msgTypes, + }.Build() + File_tests_message_message_casttypes_proto = out.File + file_tests_message_message_casttypes_proto_rawDesc = nil + file_tests_message_message_casttypes_proto_goTypes = nil + file_tests_message_message_casttypes_proto_depIdxs = nil +} diff --git a/tests/message/message_casttypes.proto b/tests/message/message_casttypes.proto new file mode 100644 index 0000000..24fb0e1 --- /dev/null +++ b/tests/message/message_casttypes.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package tests.message; + +import "patch/go.proto"; + +option go_package = "github.com/alta/protopatch/tests/message"; + + + +message MessageWithCustomTypes { + string string_field = 1 [(go.field).casttype = "String"]; + int32 int32_field = 2 [(go.field).casttype = "Int32"]; + int64 int64_field = 3 [(go.field).casttype = "Int64"]; + float float_field = 4 [(go.field).casttype = "Float"]; + double double_field = 5 [(go.field).casttype = "Double"]; + uint32 uint32_field = 6 [(go.field).casttype = "Uint32"]; + uint64 uint64_field = 7 [(go.field).casttype = "Uint64"]; +} + +message MessageWithOneOfCustomType { + oneof one_of { + string string_field = 1 [(go.field).casttype = "String"]; + int64 int64_field = 3 [(go.field).casttype = "Int64"]; + } +} diff --git a/tests/message/message_renames.proto b/tests/message/message_renames.proto index cb91203..0c5986e 100644 --- a/tests/message/message_renames.proto +++ b/tests/message/message_renames.proto @@ -6,6 +6,7 @@ import "patch/go.proto"; option go_package = "github.com/alta/protopatch/tests/message"; + message Francis { option (go.message).name = 'Frank'; } @@ -40,5 +41,5 @@ message MessageWithEmbeddedField { } message Embedded { - string message = 1; + string message = 1; } diff --git a/tests/message/message_test.go b/tests/message/message_test.go index 689d871..ca0a15c 100644 --- a/tests/message/message_test.go +++ b/tests/message/message_test.go @@ -3,6 +3,7 @@ package message import ( "testing" + "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "github.com/alta/protopatch/tests" @@ -55,13 +56,14 @@ func TestRenamedInnerMessage(t *testing.T) { } func TestMessageWithRenamedField(t *testing.T) { - m := &MessageWithRenamedField{} + m := &MessageWithRenamedField{ + ID: 66, + } tests.ValidateMessage(t, m) var _ int32 = m.ID var _ int32 = m.GetID() } - func TestMessageWithEmbeddedFields(t *testing.T) { message := "noop" m := &MessageWithEmbeddedField{ @@ -107,3 +109,32 @@ func TestExtendedMessage(t *testing.T) { _ = proto.GetExtension(m, ExtGamma).(string) _ = proto.GetExtension(m, ExtDelta).(string) } + +func TestMessageWithCustomTypes(t *testing.T) { + m := &MessageWithCustomTypes{ + StringField: "42", + Int32Field: 42, + Int64Field: 42, + FloatField: 42, + DoubleField: 42, + Uint32Field: 42, + Uint64Field: 42, + } + + tests.ValidateMessage(t, m) + var _ string = string(m.StringField) + var _ int32 = int32(m.Int32Field) + var _ int64 = int64(m.Int64Field) + var _ float32 = float32(m.FloatField) + var _ float64 = float64(m.DoubleField) + var _ uint32 = uint32(m.Uint32Field) + var _ uint64 = uint64(m.Uint64Field) + + assert.Equal(t, String("42"), m.StringField) + assert.Equal(t, Int32(42), m.Int32Field) + assert.Equal(t, Int64(42), m.Int64Field) + assert.Equal(t, Float(42), m.FloatField) + assert.Equal(t, Double(42), m.DoubleField) + assert.Equal(t, Uint32(42), m.Uint32Field) + assert.Equal(t, Uint64(42), m.Uint64Field) +} diff --git a/tests/plugin/validate.pb.go b/tests/plugin/validate.pb.go index 7ca9a86..896a146 100644 --- a/tests/plugin/validate.pb.go +++ b/tests/plugin/validate.pb.go @@ -8,6 +8,7 @@ package plugin import ( _ "github.com/alta/protopatch/patch/gopb" + github_com_alta_protopatch_tests_message "github.com/alta/protopatch/tests/message" _ "github.com/envoyproxy/protoc-gen-validate/validate" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -76,9 +77,9 @@ type Interface struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Status InterfaceStatus `protobuf:"varint,2,opt,name=status,proto3,enum=tests.plugin.Interface_Status" json:"status,omitempty"` - Addresses []*IPAddress `protobuf:"bytes,3,rep,name=addresses,proto3" json:"addresses,omitempty"` + Name github_com_alta_protopatch_tests_message.Name `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Status InterfaceStatus `protobuf:"varint,2,opt,name=status,proto3,enum=tests.plugin.Interface_Status" json:"status,omitempty"` + Addresses []*IPAddress `protobuf:"bytes,3,rep,name=addresses,proto3" json:"addresses,omitempty"` } func (x *Interface) Reset() { @@ -113,7 +114,7 @@ func (*Interface) Descriptor() ([]byte, []int) { return file_tests_plugin_validate_proto_rawDescGZIP(), []int{0} } -func (x *Interface) GetName() string { +func (x *Interface) GetName() github_com_alta_protopatch_tests_message.Name { if x != nil { return x.Name } @@ -222,34 +223,39 @@ var file_tests_plugin_validate_proto_rawDesc = []byte{ 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x1a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x91, 0x02, 0x0a, 0x09, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, - 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, - 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x2e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x0a, 0xfa, 0x42, 0x07, 0x82, 0x01, 0x04, 0x10, 0x01, - 0x20, 0x00, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x35, 0x0a, 0x09, 0x61, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x50, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, - 0x73, 0x22, 0x75, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x20, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x1a, 0x13, 0xca, 0xb5, 0x03, 0x0f, 0x0a, 0x0d, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x12, 0x16, 0x0a, - 0x02, 0x55, 0x50, 0x10, 0x01, 0x1a, 0x0e, 0xca, 0xb5, 0x03, 0x0a, 0x0a, 0x08, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x55, 0x70, 0x12, 0x1a, 0x0a, 0x04, 0x44, 0x4f, 0x57, 0x4e, 0x10, 0x02, 0x1a, - 0x10, 0xca, 0xb5, 0x03, 0x0c, 0x0a, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x44, 0x6f, 0x77, - 0x6e, 0x1a, 0x15, 0xca, 0xb5, 0x03, 0x11, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, - 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x69, 0x0a, 0x09, 0x49, 0x50, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x27, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x42, 0x11, 0xca, 0xb5, 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x34, 0xfa, - 0x42, 0x04, 0x72, 0x02, 0x78, 0x01, 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x28, - 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xca, 0xb5, - 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x36, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x80, 0x01, 0x01, - 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x42, 0x09, 0x0a, 0x07, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, - 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe0, 0x02, 0x0a, 0x09, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, + 0x63, 0x65, 0x12, 0x61, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x42, 0x4d, 0xca, 0xb5, 0x03, 0x2f, 0x1a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, + 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0xfa, 0x42, 0x17, 0x72, 0x15, 0x10, 0x02, 0x18, 0x0a, 0x32, 0x0f, + 0x5b, 0x30, 0x2d, 0x39, 0x61, 0x2d, 0x7a, 0x41, 0x2d, 0x5a, 0x2e, 0x2d, 0x5f, 0x5d, 0x2a, 0x52, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x2e, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x0a, 0xfa, 0x42, 0x07, 0x82, 0x01, 0x04, 0x10, 0x01, 0x20, + 0x00, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x35, 0x0a, 0x09, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x74, + 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x50, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, + 0x22, 0x75, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x20, 0x0a, 0x07, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x1a, 0x13, 0xca, 0xb5, 0x03, 0x0f, 0x0a, 0x0d, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x12, 0x16, 0x0a, 0x02, + 0x55, 0x50, 0x10, 0x01, 0x1a, 0x0e, 0xca, 0xb5, 0x03, 0x0a, 0x0a, 0x08, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x12, 0x1a, 0x0a, 0x04, 0x44, 0x4f, 0x57, 0x4e, 0x10, 0x02, 0x1a, 0x10, + 0xca, 0xb5, 0x03, 0x0c, 0x0a, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x44, 0x6f, 0x77, 0x6e, + 0x1a, 0x15, 0xca, 0xb5, 0x03, 0x11, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x69, 0x0a, 0x09, 0x49, 0x50, 0x41, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x12, 0x27, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x42, 0x11, 0xca, 0xb5, 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x34, 0xfa, 0x42, + 0x04, 0x72, 0x02, 0x78, 0x01, 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x28, 0x0a, + 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xca, 0xb5, 0x03, + 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x36, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x80, 0x01, 0x01, 0x48, + 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x42, 0x09, 0x0a, 0x07, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, + 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/tests/plugin/validate.pb.validate.go b/tests/plugin/validate.pb.validate.go index 4ba1073..0b374b6 100644 --- a/tests/plugin/validate.pb.validate.go +++ b/tests/plugin/validate.pb.validate.go @@ -40,7 +40,19 @@ func (m *Interface) Validate() error { return nil } - // no validation rules for Name + if l := utf8.RuneCountInString(string(m.GetName())); l < 2 || l > 10 { + return InterfaceValidationError{ + field: "Name", + reason: "value length must be between 2 and 10 runes, inclusive", + } + } + + if !_Interface_Name_Pattern.MatchString(string(m.GetName())) { + return InterfaceValidationError{ + field: "Name", + reason: "value does not match regex pattern \"[0-9a-zA-Z.-_]*\"", + } + } if _, ok := _Interface_Status_NotInLookup[m.GetStatus()]; ok { return InterfaceValidationError{ @@ -128,6 +140,8 @@ var _ interface { ErrorName() string } = InterfaceValidationError{} +var _Interface_Name_Pattern = regexp.MustCompile("[0-9a-zA-Z.-_]*") + var _Interface_Status_NotInLookup = map[InterfaceStatus]struct{}{ 0: {}, } diff --git a/tests/plugin/validate.proto b/tests/plugin/validate.proto index 0dd5129..a26b6a9 100644 --- a/tests/plugin/validate.proto +++ b/tests/plugin/validate.proto @@ -8,7 +8,7 @@ import "validate/validate.proto"; option go_package = "github.com/alta/protopatch/tests/plugin"; message Interface { - string name = 1; + string name = 1 [(go.field).casttype = "github.com/alta/protopatch/tests/message.Name", (validate.rules).string = {min_len: 2, max_len: 10, pattern: "[0-9a-zA-Z.-_]*"}]; enum Status { option (go.enum).name = "InterfaceStatus"; UNKNOWN = 0 [(go.value).name = "StatusUnknown"]; diff --git a/tests/plugin/validate_test.go b/tests/plugin/validate_test.go index b502df7..869b0dd 100644 --- a/tests/plugin/validate_test.go +++ b/tests/plugin/validate_test.go @@ -33,9 +33,10 @@ func TestInterfaceValidate(t *testing.T) { wantErr bool }{ {"nil", nil, false}, // Weird, but OK - {"unknown", &Interface{Status: StatusUnknown}, true}, - {"up", &Interface{Status: StatusUp, Addresses: nil}, false}, - {"down", &Interface{Status: StatusDown, Addresses: nil}, false}, + {"unknown", &Interface{Name: "eth0", Status: StatusUnknown}, true}, + {"up", &Interface{Name: "eth0", Status: StatusUp, Addresses: nil}, false}, + {"down", &Interface{Name: "eth0", Status: StatusDown, Addresses: nil}, false}, + {"invalid name", &Interface{Name: "a", Status: StatusDown, Addresses: nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {