Skip to content

Commit

Permalink
add casttype support for repeated scalar fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Adphi committed Sep 24, 2021
1 parent 416f079 commit eadb099
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 54 deletions.
37 changes: 30 additions & 7 deletions patch/casttype.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,41 @@ func (p *Patcher) patchTypeDef(id *ast.Ident, obj types.Object) {
astutil.AddNamedImport(p.fset, f, pkgImport, pkg)
castType = pkgImport + "." + name
}

castDecl := func(v *ast.Field) bool {
switch t := v.Type.(type){
case *ast.Ident:
t.Name = castType
return true
case *ast.ArrayType:
if strings.HasPrefix(castType, "[]") {
if id, ok := t.Elt.(*ast.Ident); ok {
id.Name = strings.TrimPrefix(castType, "[]")
return true
}
} else {
v.Type = &ast.Ident{
Name: castType,
}
return true
}
default:
return false
}
return false
}

// Cast Field definition
if id.Obj != nil && id.Obj.Decl != nil {
v, ok := id.Obj.Decl.(*ast.Field)
if !ok {
log.Printf("Warning: casttype declared for non-field object: %v `%s`", obj, castType)
return
}
t, ok := v.Type.(*ast.Ident)
if ok {
t.Name = castType
return
if !castDecl(v) {
log.Printf("Warning: unsupported casttype type: %T `%s`", v.Type, castType)
}
return
}
switch obj.Type().(type) {
// Cast Getter signature
Expand All @@ -47,10 +70,10 @@ func (p *Patcher) patchTypeDef(id *ast.Ident, obj types.Object) {
log.Printf("Warning: unexpected return count for getter: %v `%d`", obj, l)
return
}
if ident, ok := n.Type.Results.List[0].Type.(*ast.Ident); ok {
ident.Name = castType
return
if !castDecl(n.Type.Results.List[0]) {
log.Printf("Warning: unsupported casttype type: %T `%s`", n.Type.Results.List[0].Type, castType)
}
return
}
}

Expand Down
2 changes: 1 addition & 1 deletion patch/go.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ message Options {
// The casstype option changes the generated field type.
// All generated code assumes that this type is castable to the protocol buffer field type,
// so it does not work for messages types.
// Not supported for repeated fields.
// Rrepeated fields are also supported, both as '[]Type' or 'Types' where Types is a named slice type.
optional string casttype = 3;

// The getter option renames the generated getter method (default: Get<Field>)
Expand Down
2 changes: 1 addition & 1 deletion patch/gopb/go.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions patch/patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,6 @@ func (p *Patcher) scanField(f *protogen.Field) {
// check casttype
if typ := opts.GetCasttype(); typ != "" {
switch {
case f.Desc.IsList():
log.Printf("Warning: casttype declared for repeated field: %s", f.Desc.Name())
case f.Message != nil:
log.Printf("Warning: casttype declared for message field: %s", f.Desc.Name())
case f.Oneof != nil:
Expand Down
3 changes: 3 additions & 0 deletions tests/message/message.extensions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package message

type Name string
type Names []string

type Int32 int32

Expand All @@ -15,3 +16,5 @@ type Double float64
type Uint32 uint32

type Uint64 uint64

type Strings []string
105 changes: 86 additions & 19 deletions tests/message/message_casttypes.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions tests/message/message_casttypes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import "patch/go.proto";

option go_package = "github.com/alta/protopatch/tests/message";



message MessageWithCustomTypes {
string string_field = 1 [(go.field).casttype = "String"];
int32 int32_field = 2 [(go.field).casttype = "Int32"];
Expand All @@ -24,3 +22,11 @@ message MessageWithOneOfCustomType {
int64 int64_field = 3 [(go.field).casttype = "Int64"];
}
}

message MessageWithCustomRepeatedTypes {
repeated string repeated_string_field = 1 [(go.field).casttype = "[]String"];
}

message MessageWithRepeatedCustomTypes {
repeated string repeated_string_field = 1 [(go.field).casttype = "Strings"];
}
20 changes: 20 additions & 0 deletions tests/message/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,23 @@ func TestMessageWithCustomTypes(t *testing.T) {
assert.Equal(t, Uint32(42), m.Uint32Field)
assert.Equal(t, Uint64(42), m.Uint64Field)
}

func TestMessageWithCustomRepeatedTypes(t *testing.T) {
slice := []String{"one", "two"}
m := &MessageWithCustomRepeatedTypes{
RepeatedStringField: slice,
}
tests.ValidateMessage(t, m)
var _ []String = m.RepeatedStringField
assert.Equal(t, slice, m.RepeatedStringField)
}

func TestMessageWithRepeatedCustomTypes(t *testing.T) {
slice := Strings{"one", "two"}
m := &MessageWithRepeatedCustomTypes{
RepeatedStringField: slice,
}
tests.ValidateMessage(t, m)
var _ Strings = m.RepeatedStringField
assert.Equal(t, slice, m.RepeatedStringField)
}
Loading

0 comments on commit eadb099

Please sign in to comment.