Skip to content

Commit

Permalink
add casttype support for oneof scalars fields
Browse files Browse the repository at this point in the history
  • Loading branch information
Adphi committed Sep 24, 2021
1 parent 0da4a9e commit a16c2e5
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ proto_includes = \
protos: $(proto_files)

.PHONY: $(proto_files)
$(proto_files): # tools Makefile
$(proto_files): tools Makefile
# protoc-gen-go
protoc \
$(proto_includes) \
Expand Down
4 changes: 2 additions & 2 deletions patch/go.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ message Options {
optional bool embed = 2;

// The casstype option changes the generated field type.
// All generated code assumes that this type is castable to the protocol buffer field type.
// It does not work for structs or enums.
// All generated code assumes that this type is castable to the protocol buffer field type,
// so it does not work for messages types.
// Not supported for repeated fields.
optional string casttype = 3;

Expand Down
4 changes: 2 additions & 2 deletions patch/gopb/go.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion patch/patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,13 @@ func (p *Patcher) scanField(f *protogen.Field) {
// check casttype
if typ := opts.GetCasttype(); typ != "" {
switch {
case f.Desc.IsList():
log.Printf("Warning: casttype declared for repeated field: %s", f.Desc.Name())
case f.Message != nil:
log.Printf("Warning: casttype declared for message field: %s", f.Desc.Name())
case f.Oneof != nil:
log.Printf("Warning: casttype for oneof field not supported: %s", f.Desc.Name())
p.CastType(ident.WithChild(f.GoIdent, f.GoName), typ)
p.CastType(ident.WithChild(m.GoIdent, "Get"+f.GoName), typ)
default:
p.CastType(ident.WithChild(m.GoIdent, f.GoName), typ)
p.CastType(ident.WithChild(m.GoIdent, "Get"+f.GoName), typ)
Expand Down
4 changes: 2 additions & 2 deletions tests/message/message.extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ type Float float32

type Double float64

type UInt32 uint32
type Uint32 uint32

type UInt64 uint64
type Uint64 uint64
199 changes: 185 additions & 14 deletions tests/message/message_casttypes.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions tests/message/message_casttypes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ message MessageWithCustomTypes {
int64 int64_field = 3 [(go.field).casttype = "Int64"];
float float_field = 4 [(go.field).casttype = "Float"];
double double_field = 5 [(go.field).casttype = "Double"];
uint32 uint32_field = 6 [(go.field).casttype = "UInt32"];
uint64 uint64_field = 7 [(go.field).casttype = "UInt64"];
uint32 uint32_field = 6 [(go.field).casttype = "Uint32"];
uint64 uint64_field = 7 [(go.field).casttype = "Uint64"];
}

message MessageWithOneOfCustomType {
oneof one_of {
string string_field = 1 [(go.field).casttype = "String"];
int64 int64_field = 3 [(go.field).casttype = "Int64"];
}
}
9 changes: 9 additions & 0 deletions tests/message/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package message
import (
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"

"github.com/alta/protopatch/tests"
Expand Down Expand Up @@ -128,4 +129,12 @@ func TestMessageWithCustomTypes(t *testing.T) {
var _ float64 = float64(m.DoubleField)
var _ uint32 = uint32(m.Uint32Field)
var _ uint64 = uint64(m.Uint64Field)

assert.Equal(t, String("42"), m.StringField)
assert.Equal(t, Int32(42), m.Int32Field)
assert.Equal(t, Int64(42), m.Int64Field)
assert.Equal(t, Float(42), m.FloatField)
assert.Equal(t, Double(42), m.DoubleField)
assert.Equal(t, Uint32(42), m.Uint32Field)
assert.Equal(t, Uint64(42), m.Uint64Field)
}

0 comments on commit a16c2e5

Please sign in to comment.