diff --git a/go.mod b/go.mod index b87cf7f..0129901 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/fatih/structtag v1.2.0 github.com/iancoleman/strcase v0.1.2 // indirect github.com/lyft/protoc-gen-star v0.5.2 // indirect + github.com/stretchr/testify v1.7.0 golang.org/x/tools v0.1.6 google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index bbec383..b4724f2 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,9 @@ github.com/spf13/afero v1.3.4 h1:8q6vk3hthlpb2SouZcnBVKboxWQWMDNF38bwholZrJc= github.com/spf13/afero v1.3.4/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= diff --git a/patch/field_type.go b/patch/field_type.go new file mode 100644 index 0000000..8e3dad6 --- /dev/null +++ b/patch/field_type.go @@ -0,0 +1,202 @@ +package patch + +import ( + "go/ast" + "go/types" + "log" + "strings" + + "golang.org/x/tools/go/ast/astutil" +) + +func (p *Patcher) patchTypeDef(id *ast.Ident, obj types.Object) { + fieldType, ok := p.fieldTypes[obj] + if !ok { + return + } + + parent := p.findParentNode(id) + if pkg, name, isSlice := packageAndName(fieldType); pkg != "" { + f := p.fileOf(id) + pkgImport := packageImport(pkg) + astutil.AddNamedImport(p.fset, f, pkgImport, pkg) + fieldType = pkgImport + "." + name + if isSlice { + fieldType = "[]" + fieldType + } + } + castDecl := func(v *ast.Field) bool { + switch t := v.Type.(type) { + case *ast.Ident: + t.Name = fieldType + return true + case *ast.ArrayType: + if isSliceType(fieldType) { + if id, ok := t.Elt.(*ast.Ident); ok { + id.Name = strings.TrimPrefix(fieldType, "[]") + return true + } + } else { + v.Type = &ast.Ident{ + Name: fieldType, + } + return true + } + return false + default: + return false + } + } + + // Cast Field definition + if id.Obj != nil && id.Obj.Decl != nil { + v, ok := id.Obj.Decl.(*ast.Field) + if !ok { + log.Printf("Warning: fieldType declared for non-field object: %v `%s`", obj, fieldType) + return + } + if !castDecl(v) { + log.Printf("Warning: unsupported fieldType type: %T `%s`", v.Type, fieldType) + } + return + } + switch obj.Type().(type) { + // Cast Getter signature + case *types.Signature: + n, ok := parent.(*ast.FuncDecl) + if !ok { + log.Printf("Warning: unexpected type for getter: %v `%T`", obj, parent) + break + } + if l := len(n.Type.Results.List); l != 1 { + log.Printf("Warning: unexpected return count for getter: %v `%d`", obj, l) + return + } + if !castDecl(n.Type.Results.List[0]) { + log.Printf("Warning: unsupported fieldType type: %T `%s`", n.Type.Results.List[0].Type, fieldType) + } + return + } +} + +func (p *Patcher) patchTypeUsage(id *ast.Ident, obj types.Object) { + desiredType, ok := p.fieldTypes[obj] + if !ok { + return + } + var originalType string + switch t := obj.Type().(type) { + case *types.Basic: + originalType = t.Name() + case *types.Signature: + if t.Results().Len() != 1 { + return + } + originalType = t.Results().At(0).Type().String() + } + usageNode := p.findParentNode(id) + pkgPath, pkgName, isSlice := packageAndName(desiredType) + pkgImport := packageImport(pkgPath) + if pkgPath != "" { + desiredType = pkgImport + "." + pkgName + if isSlice { + desiredType = "[]" + desiredType + } + } + cast := func(as string, expr ast.Expr) ast.Expr { + if pkgPath != "" && as == desiredType { + f := p.fileOf(id) + // astutil.AddNamedImport already check for duplicated imports, so there is no need to do it here + astutil.AddNamedImport(p.fset, f, pkgImport, pkgPath) + } + return &ast.CallExpr{ + Fun: &ast.Ident{ + Name: as, + }, + Args: []ast.Expr{expr}, + } + } + parentNode := p.findParentNode(usageNode) + + switch usage := usageNode.(type) { + case *ast.SelectorExpr: + switch parentExpr := parentNode.(type) { + case *ast.AssignStmt: + if len(parentExpr.Lhs) != len(parentExpr.Rhs) { + return + } + for i := range parentExpr.Lhs { + if parentExpr.Lhs[i] == usage { + parentExpr.Rhs[i] = cast(desiredType, parentExpr.Rhs[i]) + return + } + } + for i := range parentExpr.Rhs { + if parentExpr.Rhs[i] == usage { + parentExpr.Rhs[i] = cast(originalType, parentExpr.Rhs[i]) + return + } + } + case *ast.CallExpr: + parent := p.findParentNode(parentExpr) + assign, isAssign := parent.(*ast.AssignStmt) + if parentExpr.Fun == usage && isAssign { + for i := range assign.Rhs { + if assign.Rhs[i] == parentExpr { + assign.Rhs[i] = cast(originalType, assign.Rhs[i]) + return + } + } + } + call, isCall := parent.(*ast.CallExpr) + if isCall { + for i := range call.Args { + if call.Args[i] == parentExpr { + call.Args[i] = cast(originalType, call.Args[i]) + return + } + } + } + for i, v := range parentExpr.Args { + if v == usage { + parentExpr.Args[i] = cast(originalType, usage) + return + } + } + case *ast.BinaryExpr: + if parentExpr.X == usage { + parentExpr.X = cast(originalType, parentExpr.X) + } + if parentExpr.Y == usage { + parentExpr.Y = cast(originalType, parentExpr.Y) + } + } + case *ast.KeyValueExpr: + if usage.Key == id { + usage.Value = cast(desiredType, usage.Value) + return + } + if usage.Value == id { + usage.Value = cast(originalType, usage.Value) + return + } + } +} + +func packageAndName(fqn string) (pkg string, name string, isSlice bool) { + isSlice = isSliceType(fqn) + fqn = strings.TrimPrefix(fqn, "[]") + parts := strings.Split(fqn, ".") + if len(parts) == 1 { + return "", fqn, isSlice + } + return strings.Join(parts[:len(parts)-1], "."), parts[len(parts)-1], isSlice +} + +func isSliceType(typeName string) bool { + return strings.HasPrefix(typeName, "[]") +} + +func packageImport(pkg string) string { + return strings.Replace(strings.Replace(pkg, "/", "_", -1), ".", "_", -1) +} diff --git a/patch/field_type_test.go b/patch/field_type_test.go new file mode 100644 index 0000000..9afded4 --- /dev/null +++ b/patch/field_type_test.go @@ -0,0 +1,245 @@ +package patch + +import ( + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/compiler/protogen" +) + +const ( + srcDef = `package foo + +import ( + "fmt" +) + +type String string + +type Message struct { + Content string +} + +func (m *Message) GetContent() string { + if m != nil { + return m.Content + } + return "" +} + +func print(s string) { + fmt.Println(s) +} +` + wantDef = `package foo + +import ( + "fmt" +) + +type String string + +type Message struct { + Content String +} + +func (m *Message) GetContent() String { + if m != nil { + return m.Content + } + return "" +} + +func print(s string) { + fmt.Println(s) +} +` +) + +const ( + fileName = "foo.go" + packageName = "foo" + fieldName = "Content" + msgName = "Message" + fieldType = "String" +) + +func prepareCastType(src string) (*Patcher, *ast.File, error) { + p, err := NewPatcher(&protogen.Plugin{}) + if err != nil { + return nil, nil, err + } + p.filesByName = make(map[string]*ast.File) + p.info = &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + p.fset = token.NewFileSet() + file, err := parser.ParseFile(p.fset, fileName, src, parser.ParseComments) + if err != nil { + return nil, nil, err + } + pkg := NewPackage(packageName, packageName) + if err := pkg.AddFile(fileName, file); err != nil { + return nil, nil, err + } + if err := pkg.Check(basicImporter{p}, p.fset, p.info); err != nil { + return nil, nil, err + } + p.filesByName[fileName] = file + p.packagesByPath[packageName] = pkg + p.packagesByName[packageName] = pkg + p.packages = append(p.packages, pkg) + p.Type(protogen.GoIdent{GoName: msgName + "." + fieldName, GoImportPath: packageName}, fieldType) + p.Type(protogen.GoIdent{GoName: msgName + "." + "Get" + fieldName, GoImportPath: packageName}, fieldType) + // Map cast types + for id, typ := range p.types { + obj, _ := p.find(id) + if obj == nil { + continue + } + p.fieldTypes[obj] = typ + } + return p, file, nil +} + +func TestScalarCastType(t *testing.T) { + tests := []struct{ + name string + src string + want string + }{ + { + name: "cast definition", + src: srcDef, + want: wantDef, + }, + { + name: "cast field initialization", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{Content: s} +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{Content: String(s)} +} +`, + }, + { + name: "cast field assignation", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{} + msg.Content = s +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{} + msg.Content = String(s) +} +`, + }, + { + name: "cast field assignation usage", + src: srcDef+` +func useContent() { + var s string + msg := &Message{Content: "ok"} + s = msg.Content +} +`, + want: wantDef+` +func useContent() { + var s string + msg := &Message{Content: String("ok")} + s = string(msg.Content) +} +`, + }, + { + name: "cast field used as argument", + src: srcDef+` +func useContent() { + msg := &Message{Content: "ok"} + var s string + s = msg.Content +} +`, + want: wantDef+` +func useContent() { + msg := &Message{Content: String("ok")} + var s string + s = string(msg.Content) +} +`, + }, + { + name: "cast full code", + src: srcDef+` +func useContent() { + s := "ok" + msg := &Message{Content: s} + print(msg.Content) + print(msg.GetContent()) + msg.Content = s + useContentGetter(msg) +} + +func useContentGetter(msg *Message) { + var s string + s = msg.GetContent() + s = msg.Content + s = msg.Content + "..." + print(s) +} +`, + want: wantDef+` +func useContent() { + s := "ok" + msg := &Message{Content: String(s)} + print(string(msg.Content)) + print(string(msg.GetContent())) + msg.Content = String(s) + useContentGetter(msg) +} + +func useContentGetter(msg *Message) { + var s string + s = string(msg.GetContent()) + s = string(msg.Content) + s = string(msg.Content) + "..." + print(s) +} +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, file, err := prepareCastType(tt.src) + if err != nil { + t.Errorf("failed to initilialize test") + return + } + + if err := p.patchGoFiles(); err != nil { + t.Fatal(err) + } + got := p.nodeToString(file) + assert.Equal(t, tt.want, got) + }) + } + +} diff --git a/patch/go.proto b/patch/go.proto index 31c5675..6bafb6f 100644 --- a/patch/go.proto +++ b/patch/go.proto @@ -21,6 +21,12 @@ message Options { // See https://golang.org/ref/spec#Struct_types. optional bool embed = 2; + // The type option changes the generated field type. + // All generated code assumes that this type is castable to the protocol buffer field type, + // so it does not work for messages types. + // Repeated fields are also supported, both as '[]Type' or 'Types' where Types is a named slice type. + optional string type = 3; + // The getter option renames the generated getter method (default: Get) // so a custom getter can be implemented in its place. optional string getter = 10; // TODO: implement this diff --git a/patch/gopb/go.pb.go b/patch/gopb/go.pb.go index 45dc636..5f28727 100644 --- a/patch/gopb/go.pb.go +++ b/patch/gopb/go.pb.go @@ -38,6 +38,11 @@ type Options struct { // Only message types can be embedded. Oneof fields cannot be embedded. // See https://golang.org/ref/spec#Struct_types. Embed *bool `protobuf:"varint,2,opt,name=embed" json:"embed,omitempty"` + // The type option changes the generated field type. + // All generated code assumes that this type is castable to the protocol buffer field type, + // so it does not work for messages types. + // Repeated fields are also supported, both as '[]Type' or 'Types' where Types is a named slice type. + Type *string `protobuf:"bytes,3,opt,name=type" json:"type,omitempty"` // The getter option renames the generated getter method (default: Get) // so a custom getter can be implemented in its place. Getter *string `protobuf:"bytes,10,opt,name=getter" json:"getter,omitempty"` // TODO: implement this @@ -99,6 +104,13 @@ func (x *Options) GetEmbed() bool { return false } +func (x *Options) GetType() string { + if x != nil && x.Type != nil { + return *x.Type + } + return "" +} + func (x *Options) GetGetter() string { if x != nil && x.Getter != nil { return *x.Getter @@ -328,58 +340,59 @@ var file_patch_go_proto_rawDesc = []byte{ 0x0a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x67, 0x6f, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa0, 0x01, 0x0a, 0x07, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb4, 0x01, 0x0a, 0x07, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, - 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x67, 0x65, - 0x74, 0x74, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, 0x18, 0x14, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x65, 0x72, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x74, 0x72, 0x69, - 0x6e, 0x67, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x74, 0x72, - 0x69, 0x6e, 0x67, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0xc3, 0x01, 0x0a, 0x0b, 0x4c, 0x69, - 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, - 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x12, 0x1e, 0x0a, - 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x20, 0x0a, - 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x18, 0x0a, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, 0x3a, - 0x47, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, - 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x3a, 0x41, 0x0a, 0x05, 0x66, 0x69, 0x65, 0x6c, - 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x3a, 0x41, 0x0a, 0x05, 0x6f, - 0x6e, 0x65, 0x6f, 0x66, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, - 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x3a, 0x3e, - 0x0a, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x4f, 0x70, 0x74, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x12, 0x12, 0x0a, 0x04, + 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, + 0x12, 0x16, 0x0a, 0x06, 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x67, 0x65, 0x74, 0x74, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x61, 0x67, 0x73, + 0x18, 0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, 0x12, 0x1a, 0x0a, 0x08, + 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x12, 0x23, 0x0a, 0x0d, 0x73, 0x74, 0x72, 0x69, + 0x6e, 0x67, 0x65, 0x72, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x65, 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0xc3, 0x01, + 0x0a, 0x0b, 0x4c, 0x69, 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x12, + 0x1a, 0x0a, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x08, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x66, + 0x69, 0x65, 0x6c, 0x64, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x66, 0x69, 0x65, + 0x6c, 0x64, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x05, 0x65, 0x6e, 0x75, 0x6d, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, + 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, 0x73, 0x6d, 0x73, + 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x69, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x6c, 0x69, + 0x73, 0x6d, 0x73, 0x3a, 0x47, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1f, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x3a, 0x41, 0x0a, 0x05, + 0x66, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, - 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x3a, 0x45, - 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, - 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, - 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x42, 0x0a, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x12, 0x1c, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x46, 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x67, 0x6f, 0x2e, 0x4c, 0x69, 0x6e, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x42, 0x27, 0x5a, 0x25, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, - 0x70, 0x62, + 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x3a, + 0x41, 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, + 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x05, 0x6f, 0x6e, 0x65, + 0x6f, 0x66, 0x3a, 0x3e, 0x0a, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, + 0x6d, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x65, 0x6e, + 0x75, 0x6d, 0x3a, 0x45, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, + 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd9, + 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0b, 0x2e, 0x67, 0x6f, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x73, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x42, 0x0a, 0x04, 0x6c, 0x69, 0x6e, + 0x74, 0x12, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, + 0xd9, 0x36, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x67, 0x6f, 0x2e, 0x4c, 0x69, 0x6e, 0x74, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x04, 0x6c, 0x69, 0x6e, 0x74, 0x42, 0x27, 0x5a, + 0x25, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, + 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x70, 0x61, 0x74, 0x63, + 0x68, 0x2f, 0x67, 0x6f, 0x70, 0x62, } var ( diff --git a/patch/patcher.go b/patch/patcher.go index 2c70f95..026bf91 100644 --- a/patch/patcher.go +++ b/patch/patcher.go @@ -6,6 +6,7 @@ import ( "go/ast" "go/format" "go/parser" + "go/printer" "go/token" "go/types" "log" @@ -47,8 +48,10 @@ type Patcher struct { objectRenames map[types.Object]string tags map[protogen.GoIdent]string fieldTags map[types.Object]string - embeds map[protogen.GoIdent]string - fieldEmbeds map[types.Object]string + embeds map[protogen.GoIdent]string + types map[protogen.GoIdent]string + fieldEmbeds map[types.Object]string + fieldTypes map[types.Object]string } // NewPatcher returns an initialized Patcher for gen. @@ -57,16 +60,18 @@ func NewPatcher(gen *protogen.Plugin) (*Patcher, error) { gen: gen, packagesByPath: make(map[string]*Package), packagesByName: make(map[string]*Package), - renames: make(map[protogen.GoIdent]string), - typeRenames: make(map[protogen.GoIdent]string), - valueRenames: make(map[protogen.GoIdent]string), - fieldRenames: make(map[protogen.GoIdent]string), - methodRenames: make(map[protogen.GoIdent]string), - objectRenames: make(map[types.Object]string), - tags: make(map[protogen.GoIdent]string), - fieldTags: make(map[types.Object]string), - embeds: make(map[protogen.GoIdent]string), - fieldEmbeds: make(map[types.Object]string), + renames: make(map[protogen.GoIdent]string), + typeRenames: make(map[protogen.GoIdent]string), + valueRenames: make(map[protogen.GoIdent]string), + fieldRenames: make(map[protogen.GoIdent]string), + methodRenames: make(map[protogen.GoIdent]string), + objectRenames: make(map[types.Object]string), + tags: make(map[protogen.GoIdent]string), + fieldTags: make(map[types.Object]string), + embeds: make(map[protogen.GoIdent]string), + fieldEmbeds: make(map[types.Object]string), + types: make(map[protogen.GoIdent]string), + fieldTypes: make(map[types.Object]string), } return p, p.scan() } @@ -310,6 +315,20 @@ func (p *Patcher) scanField(f *protogen.Field) { p.RenameMethod(ident.WithChild(m.GoIdent, "Get"+f.GoName), "Get"+newName) // Getter } + // check type + if fieldType := opts.GetType(); fieldType != "" { + switch { + case f.Message != nil && !f.Desc.IsList(): + log.Printf("Warning: type declared for message field: %s", f.Desc.Name()) + case f.Oneof != nil: + p.Type(ident.WithChild(f.GoIdent, f.GoName), fieldType) + p.Type(ident.WithChild(m.GoIdent, "Get"+f.GoName), fieldType) + default: + p.Type(ident.WithChild(m.GoIdent, f.GoName), fieldType) + p.Type(ident.WithChild(m.GoIdent, "Get"+f.GoName), fieldType) + } + } + // Add or replace any struct tags? tags := opts.GetTags() if tags != "" { @@ -397,6 +416,14 @@ func (p *Patcher) nameFor(id protogen.GoIdent) string { return ident.LeafName(id) } +// Type casts the Go struct field as the desired type +// The typeName value must be a named type, e.g.: "type String string" +// It can also be a fully qualified type name, e.g.: "go.repo.com/mymodule/types.String" +func (p *Patcher) Type(id protogen.GoIdent, typeName string) { + p.types[id] = typeName + log.Printf("Cast type:\t%s.%s → %s", id.GoImportPath, id.GoName, typeName) +} + // Tag adds the specified struct tags to the field specified by selector, // in the form of "Message.Field". The tags argument should omit outer backticks (`). // The value of id.GoName should be the original generated identifier name, not a renamed identifier. @@ -540,6 +567,15 @@ func (p *Patcher) checkGoFiles() error { } } + // Map cast types + for id, typ := range p.types { + obj, _ := p.find(id) + if obj == nil { + continue + } + p.fieldTypes[obj] = typ + } + // Map struct tags. for id, tags := range p.tags { obj, _ := p.find(id) @@ -676,6 +712,7 @@ func (p *Patcher) serializeGoFiles(res *pluginpb.CodeGeneratorResponse) error { func (p *Patcher) patchGoFiles() error { log.Printf("\nDefs") for id, obj := range p.info.Defs { + p.patchTypeDef(id, obj) p.patchIdent(id, obj, true) p.patchTags(id, obj) // if id.IsExported() { @@ -686,6 +723,7 @@ func (p *Patcher) patchGoFiles() error { log.Printf("\nUses\n") for id, obj := range p.info.Uses { + p.patchTypeUsage(id, obj) p.patchIdent(id, obj, false) } @@ -715,6 +753,26 @@ func (p *Patcher) patchIdent(id *ast.Ident, obj types.Object, isDecl bool) { } } +func (p *Patcher) nodeToString(n ast.Node) string { + b := &bytes.Buffer{} + if err := printer.Fprint(b, p.fset, n); err != nil { + log.Fatal(err) + } + return b.String() +} + +func (p *Patcher) findParentNode(id ast.Node) ast.Node { + var node ast.Node + astutil.Apply(p.fileOf(id), func(cursor *astutil.Cursor) bool { + if cursor.Node() == id { + node = cursor.Parent() + return false + } + return true + }, nil) + return node +} + func (p *Patcher) patchTags(id *ast.Ident, obj types.Object) { fieldTags := p.fieldTags[obj] if fieldTags == "" || id.Obj == nil { @@ -768,11 +826,7 @@ func (p *Patcher) patchComments(id *ast.Ident, repl string) { // Borrowed from https://github.com/golang/tools/blob/HEAD/refactor/rename/rename.go#L543 func (p *Patcher) findCommentGroups(id *ast.Ident) (doc *ast.CommentGroup, comment *ast.CommentGroup) { - tf := p.fset.File(id.Pos()) - if tf == nil { - return - } - f := p.filesByName[tf.Name()] + f := p.fileOf(id) if f == nil { return } @@ -802,6 +856,14 @@ func (p *Patcher) findCommentGroups(id *ast.Ident) (doc *ast.CommentGroup, comme return } +func (p *Patcher) fileOf(node ast.Node) *ast.File { + tf := p.fset.File(node.Pos()) + if tf == nil { + return nil + } + return p.filesByName[tf.Name()] +} + func patchCommentGroup(c *ast.CommentGroup, x *regexp.Regexp, repl string) { if c == nil { return @@ -834,3 +896,4 @@ func typeString(obj types.Object) string { } return obj.Type().String() } + diff --git a/tests/helpers.go b/tests/helpers.go index 4e7b89b..4bb6687 100644 --- a/tests/helpers.go +++ b/tests/helpers.go @@ -26,6 +26,9 @@ func ValidateEnum(t *testing.T, e protoreflect.Enum, names EnumNames, values Enu // ValidateMessage performs basic validation of a message. func ValidateMessage(t *testing.T, m proto.Message) { // TODO: add some validation + if m == nil { + return + } b, err := proto.Marshal(m) if err != nil { t.Errorf("failed to marshal message: %v", err) diff --git a/tests/message/message.extensions.go b/tests/message/message.extensions.go new file mode 100644 index 0000000..1fe14f0 --- /dev/null +++ b/tests/message/message.extensions.go @@ -0,0 +1,20 @@ +package message + +type Name string +type Names []string + +type Int32 int32 + +type Int64 int64 + +type String string + +type Float float32 + +type Double float64 + +type Uint32 uint32 + +type Uint64 uint64 + +type Strings []string diff --git a/tests/message/message_field_types.pb.go b/tests/message/message_field_types.pb.go new file mode 100644 index 0000000..6858fe3 --- /dev/null +++ b/tests/message/message_field_types.pb.go @@ -0,0 +1,454 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.17.3 +// source: tests/message/message_field_types.proto + +package message + +import ( + _ "github.com/alta/protopatch/patch/gopb" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type MessageWithCustomTypes struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StringField String `protobuf:"bytes,1,opt,name=string_field,json=stringField,proto3" json:"string_field,omitempty"` + Int32Field Int32 `protobuf:"varint,2,opt,name=int32_field,json=int32Field,proto3" json:"int32_field,omitempty"` + Int64Field Int64 `protobuf:"varint,3,opt,name=int64_field,json=int64Field,proto3" json:"int64_field,omitempty"` + FloatField Float `protobuf:"fixed32,4,opt,name=float_field,json=floatField,proto3" json:"float_field,omitempty"` + DoubleField Double `protobuf:"fixed64,5,opt,name=double_field,json=doubleField,proto3" json:"double_field,omitempty"` + Uint32Field Uint32 `protobuf:"varint,6,opt,name=uint32_field,json=uint32Field,proto3" json:"uint32_field,omitempty"` + Uint64Field Uint64 `protobuf:"varint,7,opt,name=uint64_field,json=uint64Field,proto3" json:"uint64_field,omitempty"` +} + +func (x *MessageWithCustomTypes) Reset() { + *x = MessageWithCustomTypes{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_field_types_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithCustomTypes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithCustomTypes) ProtoMessage() {} + +func (x *MessageWithCustomTypes) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_field_types_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithCustomTypes.ProtoReflect.Descriptor instead. +func (*MessageWithCustomTypes) Descriptor() ([]byte, []int) { + return file_tests_message_message_field_types_proto_rawDescGZIP(), []int{0} +} + +func (x *MessageWithCustomTypes) GetStringField() String { + if x != nil { + return x.StringField + } + return "" +} + +func (x *MessageWithCustomTypes) GetInt32Field() Int32 { + if x != nil { + return x.Int32Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetInt64Field() Int64 { + if x != nil { + return x.Int64Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetFloatField() Float { + if x != nil { + return x.FloatField + } + return 0 +} + +func (x *MessageWithCustomTypes) GetDoubleField() Double { + if x != nil { + return x.DoubleField + } + return 0 +} + +func (x *MessageWithCustomTypes) GetUint32Field() Uint32 { + if x != nil { + return x.Uint32Field + } + return 0 +} + +func (x *MessageWithCustomTypes) GetUint64Field() Uint64 { + if x != nil { + return x.Uint64Field + } + return 0 +} + +type MessageWithOneOfCustomType struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to OneOf: + // *MessageWithOneOfCustomType_StringField + // *MessageWithOneOfCustomType_Int64Field + OneOf isMessageWithOneOfCustomType_OneOf `protobuf_oneof:"one_of"` +} + +func (x *MessageWithOneOfCustomType) Reset() { + *x = MessageWithOneOfCustomType{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_field_types_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithOneOfCustomType) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithOneOfCustomType) ProtoMessage() {} + +func (x *MessageWithOneOfCustomType) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_field_types_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithOneOfCustomType.ProtoReflect.Descriptor instead. +func (*MessageWithOneOfCustomType) Descriptor() ([]byte, []int) { + return file_tests_message_message_field_types_proto_rawDescGZIP(), []int{1} +} + +func (m *MessageWithOneOfCustomType) GetOneOf() isMessageWithOneOfCustomType_OneOf { + if m != nil { + return m.OneOf + } + return nil +} + +func (x *MessageWithOneOfCustomType) GetStringField() String { + if x, ok := x.GetOneOf().(*MessageWithOneOfCustomType_StringField); ok { + return x.StringField + } + return "" +} + +func (x *MessageWithOneOfCustomType) GetInt64Field() Int64 { + if x, ok := x.GetOneOf().(*MessageWithOneOfCustomType_Int64Field); ok { + return x.Int64Field + } + return 0 +} + +type isMessageWithOneOfCustomType_OneOf interface { + isMessageWithOneOfCustomType_OneOf() +} + +type MessageWithOneOfCustomType_StringField struct { + StringField String `protobuf:"bytes,1,opt,name=string_field,json=stringField,proto3,oneof"` +} + +type MessageWithOneOfCustomType_Int64Field struct { + Int64Field Int64 `protobuf:"varint,3,opt,name=int64_field,json=int64Field,proto3,oneof"` +} + +func (*MessageWithOneOfCustomType_StringField) isMessageWithOneOfCustomType_OneOf() {} + +func (*MessageWithOneOfCustomType_Int64Field) isMessageWithOneOfCustomType_OneOf() {} + +type MessageWithRepeatedCustomType struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + RepeatedStringField []String `protobuf:"bytes,1,rep,name=repeated_string_field,json=repeatedStringField,proto3" json:"repeated_string_field,omitempty"` +} + +func (x *MessageWithRepeatedCustomType) Reset() { + *x = MessageWithRepeatedCustomType{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_field_types_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithRepeatedCustomType) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithRepeatedCustomType) ProtoMessage() {} + +func (x *MessageWithRepeatedCustomType) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_field_types_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithRepeatedCustomType.ProtoReflect.Descriptor instead. +func (*MessageWithRepeatedCustomType) Descriptor() ([]byte, []int) { + return file_tests_message_message_field_types_proto_rawDescGZIP(), []int{2} +} + +func (x *MessageWithRepeatedCustomType) GetRepeatedStringField() []String { + if x != nil { + return x.RepeatedStringField + } + return nil +} + +type MessageWithCustomRepeatedType struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + RepeatedStringField Strings `protobuf:"bytes,1,rep,name=repeated_string_field,json=repeatedStringField,proto3" json:"repeated_string_field,omitempty"` +} + +func (x *MessageWithCustomRepeatedType) Reset() { + *x = MessageWithCustomRepeatedType{} + if protoimpl.UnsafeEnabled { + mi := &file_tests_message_message_field_types_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *MessageWithCustomRepeatedType) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MessageWithCustomRepeatedType) ProtoMessage() {} + +func (x *MessageWithCustomRepeatedType) ProtoReflect() protoreflect.Message { + mi := &file_tests_message_message_field_types_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MessageWithCustomRepeatedType.ProtoReflect.Descriptor instead. +func (*MessageWithCustomRepeatedType) Descriptor() ([]byte, []int) { + return file_tests_message_message_field_types_proto_rawDescGZIP(), []int{3} +} + +func (x *MessageWithCustomRepeatedType) GetRepeatedStringField() Strings { + if x != nil { + return x.RepeatedStringField + } + return nil +} + +var File_tests_message_message_field_types_proto protoreflect.FileDescriptor + +var file_tests_message_message_field_types_proto_rawDesc = []byte{ + 0x0a, 0x27, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2f, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x74, 0x79, + 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x74, 0x65, 0x73, 0x74, 0x73, + 0x2e, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, + 0x67, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe6, 0x02, 0x0a, 0x16, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, 0x68, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, + 0x70, 0x65, 0x73, 0x12, 0x2f, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, + 0x65, 0x6c, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, + 0x06, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, + 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2c, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x5f, 0x66, 0x69, + 0x65, 0x6c, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, + 0x05, 0x49, 0x6e, 0x74, 0x33, 0x32, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x46, 0x69, 0x65, + 0x6c, 0x64, 0x12, 0x2c, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x49, + 0x6e, 0x74, 0x36, 0x34, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x12, 0x2c, 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x02, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, 0x46, 0x6c, 0x6f, + 0x61, 0x74, 0x52, 0x0a, 0x66, 0x6c, 0x6f, 0x61, 0x74, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, 0x2f, + 0x0a, 0x0c, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x01, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x44, 0x6f, 0x75, 0x62, + 0x6c, 0x65, 0x52, 0x0b, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x12, + 0x2f, 0x0a, 0x0c, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x0d, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x55, 0x69, 0x6e, + 0x74, 0x33, 0x32, 0x52, 0x0b, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x12, 0x2f, 0x0a, 0x0c, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, + 0x18, 0x07, 0x20, 0x01, 0x28, 0x04, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x55, 0x69, + 0x6e, 0x74, 0x36, 0x34, 0x52, 0x0b, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, 0x65, 0x6c, + 0x64, 0x22, 0x89, 0x01, 0x0a, 0x1a, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, + 0x68, 0x4f, 0x6e, 0x65, 0x4f, 0x66, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x31, 0x0a, 0x0c, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0c, 0xca, 0xb5, 0x03, 0x08, 0x1a, 0x06, 0x53, 0x74, + 0x72, 0x69, 0x6e, 0x67, 0x48, 0x00, 0x52, 0x0b, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, 0x69, + 0x65, 0x6c, 0x64, 0x12, 0x2e, 0x0a, 0x0b, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x5f, 0x66, 0x69, 0x65, + 0x6c, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x42, 0x0b, 0xca, 0xb5, 0x03, 0x07, 0x1a, 0x05, + 0x49, 0x6e, 0x74, 0x36, 0x34, 0x48, 0x00, 0x52, 0x0a, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x46, 0x69, + 0x65, 0x6c, 0x64, 0x42, 0x08, 0x0a, 0x06, 0x6f, 0x6e, 0x65, 0x5f, 0x6f, 0x66, 0x22, 0x63, 0x0a, + 0x1d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, 0x68, 0x52, 0x65, 0x70, 0x65, + 0x61, 0x74, 0x65, 0x64, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x54, 0x79, 0x70, 0x65, 0x12, 0x42, + 0x0a, 0x15, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x72, 0x69, 0x6e, + 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x42, 0x0e, 0xca, + 0xb5, 0x03, 0x0a, 0x1a, 0x08, 0x5b, 0x5d, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x52, 0x13, 0x72, + 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x46, 0x69, 0x65, + 0x6c, 0x64, 0x22, 0x62, 0x0a, 0x1d, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x57, 0x69, 0x74, + 0x68, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x52, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x54, + 0x79, 0x70, 0x65, 0x12, 0x41, 0x0a, 0x15, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, + 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x18, 0x01, 0x20, 0x03, + 0x28, 0x09, 0x42, 0x0d, 0xca, 0xb5, 0x03, 0x09, 0x1a, 0x07, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, + 0x73, 0x52, 0x13, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x53, 0x74, 0x72, 0x69, 0x6e, + 0x67, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x42, 0x2a, 0x5a, 0x28, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, + 0x61, 0x74, 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_tests_message_message_field_types_proto_rawDescOnce sync.Once + file_tests_message_message_field_types_proto_rawDescData = file_tests_message_message_field_types_proto_rawDesc +) + +func file_tests_message_message_field_types_proto_rawDescGZIP() []byte { + file_tests_message_message_field_types_proto_rawDescOnce.Do(func() { + file_tests_message_message_field_types_proto_rawDescData = protoimpl.X.CompressGZIP(file_tests_message_message_field_types_proto_rawDescData) + }) + return file_tests_message_message_field_types_proto_rawDescData +} + +var file_tests_message_message_field_types_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_tests_message_message_field_types_proto_goTypes = []interface{}{ + (*MessageWithCustomTypes)(nil), // 0: tests.message.MessageWithCustomTypes + (*MessageWithOneOfCustomType)(nil), // 1: tests.message.MessageWithOneOfCustomType + (*MessageWithRepeatedCustomType)(nil), // 2: tests.message.MessageWithRepeatedCustomType + (*MessageWithCustomRepeatedType)(nil), // 3: tests.message.MessageWithCustomRepeatedType +} +var file_tests_message_message_field_types_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_tests_message_message_field_types_proto_init() } +func file_tests_message_message_field_types_proto_init() { + if File_tests_message_message_field_types_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_tests_message_message_field_types_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithCustomTypes); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_tests_message_message_field_types_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithOneOfCustomType); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_tests_message_message_field_types_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithRepeatedCustomType); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_tests_message_message_field_types_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*MessageWithCustomRepeatedType); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_tests_message_message_field_types_proto_msgTypes[1].OneofWrappers = []interface{}{ + (*MessageWithOneOfCustomType_StringField)(nil), + (*MessageWithOneOfCustomType_Int64Field)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_tests_message_message_field_types_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_tests_message_message_field_types_proto_goTypes, + DependencyIndexes: file_tests_message_message_field_types_proto_depIdxs, + MessageInfos: file_tests_message_message_field_types_proto_msgTypes, + }.Build() + File_tests_message_message_field_types_proto = out.File + file_tests_message_message_field_types_proto_rawDesc = nil + file_tests_message_message_field_types_proto_goTypes = nil + file_tests_message_message_field_types_proto_depIdxs = nil +} diff --git a/tests/message/message_field_types.proto b/tests/message/message_field_types.proto new file mode 100644 index 0000000..97eef60 --- /dev/null +++ b/tests/message/message_field_types.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package tests.message; + +import "patch/go.proto"; + +option go_package = "github.com/alta/protopatch/tests/message"; + +message MessageWithCustomTypes { + string string_field = 1 [(go.field).type = "String"]; + int32 int32_field = 2 [(go.field).type = "Int32"]; + int64 int64_field = 3 [(go.field).type = "Int64"]; + float float_field = 4 [(go.field).type = "Float"]; + double double_field = 5 [(go.field).type = "Double"]; + uint32 uint32_field = 6 [(go.field).type = "Uint32"]; + uint64 uint64_field = 7 [(go.field).type = "Uint64"]; +} + +message MessageWithOneOfCustomType { + oneof one_of { + string string_field = 1 [(go.field).type = "String"]; + int64 int64_field = 3 [(go.field).type = "Int64"]; + } +} + +message MessageWithRepeatedCustomType { + repeated string repeated_string_field = 1 [(go.field).type = "[]String"]; +} + +message MessageWithCustomRepeatedType { + repeated string repeated_string_field = 1 [(go.field).type = "Strings"]; +} diff --git a/tests/message/message_renames.proto b/tests/message/message_renames.proto index cb91203..4414dda 100644 --- a/tests/message/message_renames.proto +++ b/tests/message/message_renames.proto @@ -36,9 +36,9 @@ message MessageWithRenamedField { } message MessageWithEmbeddedField { - Embedded embedded_message = 5 [(go.field).embed = true]; + Embedded embedded_message = 5 [(go.field).embed = true]; } message Embedded { - string message = 1; + string message = 1; } diff --git a/tests/message/message_test.go b/tests/message/message_test.go index 689d871..f9beb89 100644 --- a/tests/message/message_test.go +++ b/tests/message/message_test.go @@ -3,6 +3,7 @@ package message import ( "testing" + "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "github.com/alta/protopatch/tests" @@ -55,13 +56,14 @@ func TestRenamedInnerMessage(t *testing.T) { } func TestMessageWithRenamedField(t *testing.T) { - m := &MessageWithRenamedField{} + m := &MessageWithRenamedField{ + ID: 66, + } tests.ValidateMessage(t, m) var _ int32 = m.ID var _ int32 = m.GetID() } - func TestMessageWithEmbeddedFields(t *testing.T) { message := "noop" m := &MessageWithEmbeddedField{ @@ -107,3 +109,52 @@ func TestExtendedMessage(t *testing.T) { _ = proto.GetExtension(m, ExtGamma).(string) _ = proto.GetExtension(m, ExtDelta).(string) } + +func TestMessageWithCustomTypes(t *testing.T) { + m := &MessageWithCustomTypes{ + StringField: "42", + Int32Field: 42, + Int64Field: 42, + FloatField: 42, + DoubleField: 42, + Uint32Field: 42, + Uint64Field: 42, + } + + tests.ValidateMessage(t, m) + var _ string = string(m.StringField) + var _ int32 = int32(m.Int32Field) + var _ int64 = int64(m.Int64Field) + var _ float32 = float32(m.FloatField) + var _ float64 = float64(m.DoubleField) + var _ uint32 = uint32(m.Uint32Field) + var _ uint64 = uint64(m.Uint64Field) + + assert.Equal(t, String("42"), m.StringField) + assert.Equal(t, Int32(42), m.Int32Field) + assert.Equal(t, Int64(42), m.Int64Field) + assert.Equal(t, Float(42), m.FloatField) + assert.Equal(t, Double(42), m.DoubleField) + assert.Equal(t, Uint32(42), m.Uint32Field) + assert.Equal(t, Uint64(42), m.Uint64Field) +} + +func TestMessageWithRepeatedCustomType(t *testing.T) { + slice := []String{"one", "two"} + m := &MessageWithRepeatedCustomType{ + RepeatedStringField: slice, + } + tests.ValidateMessage(t, m) + var _ []String = m.RepeatedStringField + assert.Equal(t, slice, m.RepeatedStringField) +} + +func TestMessageWithCustomRepeatedType(t *testing.T) { + slice := Strings{"one", "two"} + m := &MessageWithCustomRepeatedType{ + RepeatedStringField: slice, + } + tests.ValidateMessage(t, m) + var _ Strings = m.RepeatedStringField + assert.Equal(t, slice, m.RepeatedStringField) +} diff --git a/tests/plugin/validate.extensions.go b/tests/plugin/validate.extensions.go new file mode 100644 index 0000000..03456a2 --- /dev/null +++ b/tests/plugin/validate.extensions.go @@ -0,0 +1,3 @@ +package plugin + +type IPAddresses []*IPAddress diff --git a/tests/plugin/validate.pb.go b/tests/plugin/validate.pb.go index 7ca9a86..3abf22a 100644 --- a/tests/plugin/validate.pb.go +++ b/tests/plugin/validate.pb.go @@ -8,6 +8,7 @@ package plugin import ( _ "github.com/alta/protopatch/patch/gopb" + github_com_alta_protopatch_tests_message "github.com/alta/protopatch/tests/message" _ "github.com/envoyproxy/protoc-gen-validate/validate" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" @@ -76,9 +77,10 @@ type Interface struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Status InterfaceStatus `protobuf:"varint,2,opt,name=status,proto3,enum=tests.plugin.Interface_Status" json:"status,omitempty"` - Addresses []*IPAddress `protobuf:"bytes,3,rep,name=addresses,proto3" json:"addresses,omitempty"` + Name github_com_alta_protopatch_tests_message.Name `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Status InterfaceStatus `protobuf:"varint,2,opt,name=status,proto3,enum=tests.plugin.Interface_Status" json:"status,omitempty"` + Addresses IPAddresses `protobuf:"bytes,3,rep,name=addresses,proto3" json:"addresses,omitempty"` + Aliases github_com_alta_protopatch_tests_message.Names `protobuf:"bytes,4,rep,name=aliases,proto3" json:"aliases,omitempty"` } func (x *Interface) Reset() { @@ -113,7 +115,7 @@ func (*Interface) Descriptor() ([]byte, []int) { return file_tests_plugin_validate_proto_rawDescGZIP(), []int{0} } -func (x *Interface) GetName() string { +func (x *Interface) GetName() github_com_alta_protopatch_tests_message.Name { if x != nil { return x.Name } @@ -127,13 +129,20 @@ func (x *Interface) GetStatus() InterfaceStatus { return StatusUnknown } -func (x *Interface) GetAddresses() []*IPAddress { +func (x *Interface) GetAddresses() IPAddresses { if x != nil { return x.Addresses } return nil } +func (x *Interface) GetAliases() github_com_alta_protopatch_tests_message.Names { + if x != nil { + return x.Aliases + } + return nil +} + type IPAddress struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -222,34 +231,47 @@ var file_tests_plugin_validate_proto_rawDesc = []byte{ 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x1a, 0x0e, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x67, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x91, 0x02, 0x0a, 0x09, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, - 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, - 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x2e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x0a, 0xfa, 0x42, 0x07, 0x82, 0x01, 0x04, 0x10, 0x01, - 0x20, 0x00, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x35, 0x0a, 0x09, 0x61, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, - 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x50, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, - 0x73, 0x22, 0x75, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x20, 0x0a, 0x07, 0x55, - 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x1a, 0x13, 0xca, 0xb5, 0x03, 0x0f, 0x0a, 0x0d, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x12, 0x16, 0x0a, - 0x02, 0x55, 0x50, 0x10, 0x01, 0x1a, 0x0e, 0xca, 0xb5, 0x03, 0x0a, 0x0a, 0x08, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x55, 0x70, 0x12, 0x1a, 0x0a, 0x04, 0x44, 0x4f, 0x57, 0x4e, 0x10, 0x02, 0x1a, - 0x10, 0xca, 0xb5, 0x03, 0x0c, 0x0a, 0x0a, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x44, 0x6f, 0x77, - 0x6e, 0x1a, 0x15, 0xca, 0xb5, 0x03, 0x11, 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, - 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x69, 0x0a, 0x09, 0x49, 0x50, 0x41, 0x64, - 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x27, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x42, 0x11, 0xca, 0xb5, 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x34, 0xfa, - 0x42, 0x04, 0x72, 0x02, 0x78, 0x01, 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x28, - 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xca, 0xb5, - 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x36, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x80, 0x01, 0x01, - 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x42, 0x09, 0x0a, 0x07, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, - 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x62, 0x06, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xea, 0x03, 0x0a, 0x09, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, + 0x63, 0x65, 0x12, 0x61, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x42, 0x4d, 0xca, 0xb5, 0x03, 0x2f, 0x1a, 0x2d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, + 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0xfa, 0x42, 0x17, 0x72, 0x15, 0x10, 0x02, 0x18, 0x0a, 0x32, 0x0f, + 0x5b, 0x30, 0x2d, 0x39, 0x61, 0x2d, 0x7a, 0x41, 0x2d, 0x5a, 0x2e, 0x2d, 0x5f, 0x5d, 0x2a, 0x52, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x42, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1e, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x2e, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x42, 0x0a, 0xfa, 0x42, 0x07, 0x82, 0x01, 0x04, 0x10, 0x01, 0x20, + 0x00, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x48, 0x0a, 0x09, 0x61, 0x64, 0x64, + 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x74, + 0x65, 0x73, 0x74, 0x73, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x2e, 0x49, 0x50, 0x41, 0x64, + 0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x11, 0xca, 0xb5, 0x03, 0x0d, 0x1a, 0x0b, 0x49, 0x50, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x65, 0x73, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, + 0x73, 0x65, 0x73, 0x12, 0x75, 0x0a, 0x07, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x65, 0x73, 0x18, 0x04, + 0x20, 0x03, 0x28, 0x09, 0x42, 0x5b, 0xca, 0xb5, 0x03, 0x30, 0x1a, 0x2e, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x2f, 0x6d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0xfa, 0x42, 0x05, 0x92, 0x01, 0x02, + 0x10, 0x0a, 0xfa, 0x42, 0x1c, 0x92, 0x01, 0x19, 0x22, 0x17, 0x72, 0x15, 0x10, 0x02, 0x18, 0x0a, + 0x32, 0x0f, 0x5b, 0x30, 0x2d, 0x39, 0x61, 0x2d, 0x7a, 0x41, 0x2d, 0x5a, 0x2e, 0x2d, 0x5f, 0x5d, + 0x2a, 0x52, 0x07, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x65, 0x73, 0x22, 0x75, 0x0a, 0x06, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x20, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x1a, 0x13, 0xca, 0xb5, 0x03, 0x0f, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, + 0x6e, 0x6b, 0x6e, 0x6f, 0x77, 0x6e, 0x12, 0x16, 0x0a, 0x02, 0x55, 0x50, 0x10, 0x01, 0x1a, 0x0e, + 0xca, 0xb5, 0x03, 0x0a, 0x0a, 0x08, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x12, 0x1a, + 0x0a, 0x04, 0x44, 0x4f, 0x57, 0x4e, 0x10, 0x02, 0x1a, 0x10, 0xca, 0xb5, 0x03, 0x0c, 0x0a, 0x0a, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x44, 0x6f, 0x77, 0x6e, 0x1a, 0x15, 0xca, 0xb5, 0x03, 0x11, + 0x0a, 0x0f, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x22, 0x69, 0x0a, 0x09, 0x49, 0x50, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x27, + 0x0a, 0x04, 0x69, 0x70, 0x76, 0x34, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x42, 0x11, 0xca, 0xb5, + 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, 0x34, 0xfa, 0x42, 0x04, 0x72, 0x02, 0x78, 0x01, 0x48, + 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, 0x34, 0x12, 0x28, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x42, 0x12, 0xca, 0xb5, 0x03, 0x06, 0x0a, 0x04, 0x49, 0x50, 0x56, + 0x36, 0xfa, 0x42, 0x05, 0x72, 0x03, 0x80, 0x01, 0x01, 0x48, 0x00, 0x52, 0x04, 0x69, 0x70, 0x76, + 0x36, 0x42, 0x09, 0x0a, 0x07, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x42, 0x29, 0x5a, 0x27, + 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x6c, 0x74, 0x61, 0x2f, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x70, 0x61, 0x74, 0x63, 0x68, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, + 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/tests/plugin/validate.pb.validate.go b/tests/plugin/validate.pb.validate.go index 4ba1073..d1886ab 100644 --- a/tests/plugin/validate.pb.validate.go +++ b/tests/plugin/validate.pb.validate.go @@ -40,7 +40,19 @@ func (m *Interface) Validate() error { return nil } - // no validation rules for Name + if l := utf8.RuneCountInString(string(m.GetName())); l < 2 || l > 10 { + return InterfaceValidationError{ + field: "Name", + reason: "value length must be between 2 and 10 runes, inclusive", + } + } + + if !_Interface_Name_Pattern.MatchString(string(m.GetName())) { + return InterfaceValidationError{ + field: "Name", + reason: "value does not match regex pattern \"[0-9a-zA-Z.-_]*\"", + } + } if _, ok := _Interface_Status_NotInLookup[m.GetStatus()]; ok { return InterfaceValidationError{ @@ -71,6 +83,32 @@ func (m *Interface) Validate() error { } + if len([]string(m.GetAliases())) > 10 { + return InterfaceValidationError{ + field: "Aliases", + reason: "value must contain no more than 10 item(s)", + } + } + + for idx, item := range m.GetAliases() { + _, _ = idx, item + + if l := utf8.RuneCountInString(item); l < 2 || l > 10 { + return InterfaceValidationError{ + field: fmt.Sprintf("Aliases[%v]", idx), + reason: "value length must be between 2 and 10 runes, inclusive", + } + } + + if !_Interface_Aliases_Pattern.MatchString(item) { + return InterfaceValidationError{ + field: fmt.Sprintf("Aliases[%v]", idx), + reason: "value does not match regex pattern \"[0-9a-zA-Z.-_]*\"", + } + } + + } + return nil } @@ -128,10 +166,14 @@ var _ interface { ErrorName() string } = InterfaceValidationError{} +var _Interface_Name_Pattern = regexp.MustCompile("[0-9a-zA-Z.-_]*") + var _Interface_Status_NotInLookup = map[InterfaceStatus]struct{}{ 0: {}, } +var _Interface_Aliases_Pattern = regexp.MustCompile("[0-9a-zA-Z.-_]*") + // Validate checks the field values on IPAddress with the rules defined in the // proto definition for this message. If any rules are violated, an error is returned. func (m *IPAddress) Validate() error { diff --git a/tests/plugin/validate.proto b/tests/plugin/validate.proto index 0dd5129..a2f5a4b 100644 --- a/tests/plugin/validate.proto +++ b/tests/plugin/validate.proto @@ -8,7 +8,7 @@ import "validate/validate.proto"; option go_package = "github.com/alta/protopatch/tests/plugin"; message Interface { - string name = 1; + string name = 1 [(go.field).type = "github.com/alta/protopatch/tests/message.Name", (validate.rules).string = {min_len: 2, max_len: 10, pattern: "[0-9a-zA-Z.-_]*"}]; enum Status { option (go.enum).name = "InterfaceStatus"; UNKNOWN = 0 [(go.value).name = "StatusUnknown"]; @@ -19,7 +19,12 @@ message Interface { defined_only: true, not_in: [0] }]; - repeated IPAddress addresses = 3; + repeated IPAddress addresses = 3 [(go.field).type = "IPAddresses"]; + repeated string aliases = 4 [ + (go.field).type = "github.com/alta/protopatch/tests/message.Names", + (validate.rules).repeated = {max_items: 10}, + (validate.rules).repeated.items.string = {min_len: 2, max_len: 10, pattern: "[0-9a-zA-Z.-_]*"} + ]; } message IPAddress { diff --git a/tests/plugin/validate_test.go b/tests/plugin/validate_test.go index b502df7..e619cb9 100644 --- a/tests/plugin/validate_test.go +++ b/tests/plugin/validate_test.go @@ -33,9 +33,10 @@ func TestInterfaceValidate(t *testing.T) { wantErr bool }{ {"nil", nil, false}, // Weird, but OK - {"unknown", &Interface{Status: StatusUnknown}, true}, - {"up", &Interface{Status: StatusUp, Addresses: nil}, false}, - {"down", &Interface{Status: StatusDown, Addresses: nil}, false}, + {"unknown", &Interface{Name: "eth0", Status: StatusUnknown}, true}, + {"up", &Interface{Name: "eth0", Status: StatusUp, Addresses: nil}, false}, + {"down", &Interface{Name: "eth0", Status: StatusDown, Addresses: nil}, false}, + {"invalid name", &Interface{Name: "a", Status: StatusDown, Addresses: nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -44,6 +45,7 @@ func TestInterfaceValidate(t *testing.T) { t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) return } + var _ IPAddresses = tt.i.GetAddresses() }) } }