Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Casttypewith #659

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ regenerate:
make -C test/issue620 regenerate
make -C test/protobuffer regenerate
make -C test/issue630 regenerate
make -C test/casttypewith regenerate

make gofmt

Expand Down
177 changes: 94 additions & 83 deletions gogoproto/gogo.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions gogoproto/gogo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,5 @@ extend google.protobuf.FieldOptions {
optional bool stdduration = 65011;
optional bool wktpointer = 65012;

optional string casttypewith = 65013;
}
30 changes: 30 additions & 0 deletions gogoproto/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ func IsCastType(field *google_protobuf.FieldDescriptorProto) bool {
return false
}

func IsCastTypeWith(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCastTypeWith(field)
if len(typ) > 0 {
return true
}
return false
}

func HasCastTypeWith(file *google_protobuf.FileDescriptorProto, message *google_protobuf.DescriptorProto) bool {
for _, f := range message.GetField() {
if IsCastTypeWith(f) {
return true
}
}
return false
}

func IsCastKey(field *google_protobuf.FieldDescriptorProto) bool {
typ := GetCastKey(field)
if len(typ) > 0 {
Expand Down Expand Up @@ -173,6 +190,19 @@ func GetCastType(field *google_protobuf.FieldDescriptorProto) string {
return ""
}

func GetCastTypeWith(field *google_protobuf.FieldDescriptorProto) string {
if field == nil {
return ""
}
if field.Options != nil {
v, err := proto.GetExtension(field.Options, E_Casttypewith)
if err == nil && v.(*string) != nil {
return *(v.(*string))
}
}
return ""
}

func GetCastKey(field *google_protobuf.FieldDescriptorProto) string {
if field == nil {
return ""
Expand Down
21 changes: 20 additions & 1 deletion plugin/equal/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,26 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
isTimestamp := gogoproto.IsStdTime(field)
// oneof := field.OneofIndex != nil
if !repeated {
if ctype || isTimestamp {
if gogoproto.IsCastTypeWith(field) {
_, _, _, casterTyp, err := generator.GetCastTypeWith(field)
if err != nil {
panic(err)
}
p.P(`{`)
p.In()
p.P("__caster := &", casterTyp, "{}")
p.P(`if !__caster.Equal(this.`, fieldname, `, that1.`, fieldname, `){`)
p.In()
if verbose {
p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`)
} else {
p.P(`return false`)
}
p.Out()
p.P(`}`)
p.Out()
p.P(`}`)
} else if ctype || isTimestamp {
if nullable {
p.P(`if that1.`, fieldname, ` == nil {`)
p.In()
Expand Down
28 changes: 27 additions & 1 deletion plugin/marshalto/marshalto.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,33 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
p.encodeKey(fieldNumber, wireType)
}
case descriptor.FieldDescriptorProto_TYPE_BYTES:
if !gogoproto.IsCustomType(field) {
if gogoproto.IsCastTypeWith(field) {
if !nullable || repeated {
panic("casttypewith only supports single pointers")
}
_, _, _, casterTyp, err := generator.GetCastTypeWith(field)
if err != nil {
panic(err)
}
p.P(`{`)
p.In()
p.P("__caster := &", casterTyp, "{}")
if protoSizer {
p.P(`size := __caster.ProtoSize(m.`, fieldname, `)`)
} else {
p.P(`size := __caster.Size(m.`, fieldname, `)`)
}
p.P(`i -= size`)
p.P(`if _, err := __caster.MarshalTo(m.`, fieldname, `, dAtA[i:]); err != nil {`)
p.In()
p.P(`return 0, err`)
p.Out()
p.P(`}`)
p.Out()
p.callVarint(`size`)
p.P(`}`)
p.encodeKey(fieldNumber, wireType)
} else if !gogoproto.IsCustomType(field) {
if repeated {
val := p.reverseListRange(`m.`, fieldname)
p.P(`i -= len(`, val, `)`)
Expand Down
11 changes: 11 additions & 0 deletions plugin/populate/populate.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,17 @@ func (p *plugin) GenerateField(file *generator.FileDescriptor, message *generato
p.P(p.varGen.Next(), `:= `, funcCall)
p.P(`this.`, fieldname, ` = *`, p.varGen.Current())
}
} else if gogoproto.IsCastTypeWith(field) {
_, _, _, casterTyp, err := generator.GetCastTypeWith(field)
if err != nil {
panic(err)
}
p.P(`{`)
p.In()
p.P("__caster := &", casterTyp, "{}")
p.P(`this.`, fieldname, ` = __caster.NewPopulated()`)
p.Out()
p.P(`}`)
} else if field.IsMessage() || p.IsGroup(field) {
funcCall := p.getFuncCall(goTypName, field)
if field.IsRepeated() {
Expand Down
14 changes: 13 additions & 1 deletion plugin/size/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,19 @@ func (p *size) generateField(proto3 bool, file *generator.FileDescriptor, messag
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
}
case descriptor.FieldDescriptorProto_TYPE_BYTES:
if !gogoproto.IsCustomType(field) {
if gogoproto.IsCastTypeWith(field) {
_, _, _, casterTyp, err := generator.GetCastTypeWith(field)
if err != nil {
panic(err)
}
p.P(`{`)
p.In()
p.P("__caster := &", casterTyp, "{}")
p.P(`l = __caster.Size(m.`, fieldname, `)`)
p.P(`n+=`, strconv.Itoa(key), `+l+sov`, p.localName, `(uint64(l))`)
p.Out()
p.P(`}`)
} else if !gogoproto.IsCustomType(field) {
if repeated {
p.P(`for _, b := range m.`, fieldname, ` { `)
p.In()
Expand Down
10 changes: 10 additions & 0 deletions plugin/testgen/testgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,11 @@ func (p *testJson) Generate(imports generator.PluginImports, file *generator.Fil
if message.DescriptorProto.GetOptions().GetMapEntry() {
continue
}

if gogoproto.HasCastTypeWith(file.FileDescriptorProto, message.DescriptorProto) {
continue
}

if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) {
used = true
p.P(`func Test`, ccTypeName, `JSON(t *`, testingPkg.Use(), `.T) {`)
Expand Down Expand Up @@ -537,6 +542,11 @@ func (p *testText) Generate(imports generator.PluginImports, file *generator.Fil
if message.DescriptorProto.GetOptions().GetMapEntry() {
continue
}

if gogoproto.HasCastTypeWith(file.FileDescriptorProto, message.DescriptorProto) {
continue
}

if gogoproto.HasTestGen(file.FileDescriptorProto, message.DescriptorProto) {
used = true

Expand Down
25 changes: 24 additions & 1 deletion plugin/unmarshal/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,30 @@ func (p *unmarshal) field(file *generator.FileDescriptor, msg *generator.Descrip
p.P(`return `, p.ioPkg.Use(), `.ErrUnexpectedEOF`)
p.Out()
p.P(`}`)
if !gogoproto.IsCustomType(field) {
if gogoproto.IsCastTypeWith(field) {
// for now only support pointers, non-repeating
if oneof || !nullable || repeated {
panic("casttypewith only supports single pointers")
}
_, _, _, casterTyp, err := generator.GetCastTypeWith(field)
if err != nil {
panic(err)
}
p.P(`{`)
p.In()
p.P("__caster := &", casterTyp, "{}")
p.P("if tmp, err := __caster.Unmarshal(dAtA[iNdEx:postIndex]); err != nil {")
p.In()
p.P(`return err`)
p.Out()
p.P(`} else { `)
p.In()
p.P(`m.`, fieldname, "= tmp")
p.Out()
p.P(`}`)
p.Out()
p.P(`}`)
} else if !gogoproto.IsCustomType(field) {
if oneof {
p.P(`v := make([]byte, postIndex-iNdEx)`)
p.P(`copy(v, dAtA[iNdEx:postIndex])`)
Expand Down
54 changes: 48 additions & 6 deletions protoc-gen-gogo/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,11 @@ func (g *Generator) goTag(message *Descriptor, field *descriptor.FieldDescriptor
casttype = ",casttype=" + gogoproto.GetCastType(field)
}

casttypewith := ""
if gogoproto.IsCastTypeWith(field) {
casttype = ",casttypewith=" + gogoproto.GetCastTypeWith(field)
}

castkey := ""
if gogoproto.IsCastKey(field) {
castkey = ",castkey=" + gogoproto.GetCastKey(field)
Expand Down Expand Up @@ -1756,7 +1761,7 @@ func (g *Generator) goTag(message *Descriptor, field *descriptor.FieldDescriptor
if gogoproto.IsWktPtr(field) {
wktptr = ",wktptr"
}
return strconv.Quote(fmt.Sprintf("%s,%d,%s%s%s%s%s%s%s%s%s%s%s%s%s%s",
return strconv.Quote(fmt.Sprintf("%s,%d,%s%s%s%s%s%s%s%s%s%s%s%s%s%s%s",
wiretype,
field.GetNumber(),
optrepreq,
Expand All @@ -1768,6 +1773,7 @@ func (g *Generator) goTag(message *Descriptor, field *descriptor.FieldDescriptor
embed,
ctype,
casttype,
casttypewith,
castkey,
castvalue,
stdtime,
Expand Down Expand Up @@ -1879,6 +1885,27 @@ func (g *Generator) GoType(message *Descriptor, field *descriptor.FieldDescripto
if len(packageName) > 0 {
g.customImports = append(g.customImports, packageName)
}
case gogoproto.IsCastTypeWith(field) && (gogoproto.IsCustomType(field) || gogoproto.IsCastType(field)):
g.Fail(field.GetName() + " casttypewith is incompatible with customtype and casttype")
case gogoproto.IsCastTypeWith(field):
var casteePkg string
var casterPkg string
var err error
if field.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES {
casteePkg, typ, casterPkg, _, err = getCastTypeWith(field)
if err != nil {
g.Fail(err.Error())
}
if len(casteePkg) > 0 {
g.customImports = append(g.customImports, casteePkg)
}
if len(casterPkg) > 0 {
g.customImports = append(g.customImports, casterPkg)
}
typ = "*" + typ
} else {
g.Fail(field.GetName() + " casttypewith only works with bytes")
}
case gogoproto.IsStdTime(field):
g.customImports = append(g.customImports, "time")
typ = "time.Time"
Expand Down Expand Up @@ -1963,6 +1990,21 @@ func (g *Generator) GoMapType(d *Descriptor, field *descriptor.FieldDescriptorPr
return m
}

if gogoproto.IsCastTypeWith(field) {
var packageName string
var typ string
var err error
packageName, typ, _, _, err = getCastTypeWith(field)
if err != nil {
g.Fail(err.Error())
}
if len(packageName) > 0 {
g.customImports = append(g.customImports, packageName)
}
m.GoType = typ
return m
}

// We don't use stars, except for message-typed values.
// Message and enum types are the only two possibly foreign types used in maps,
// so record their use. They are not permitted as map keys.
Expand All @@ -1975,7 +2017,7 @@ func (g *Generator) GoMapType(d *Descriptor, field *descriptor.FieldDescriptorPr
if !gogoproto.IsNullable(m.ValueAliasField) {
valType = strings.TrimPrefix(valType, "*")
}
if !gogoproto.IsStdType(m.ValueAliasField) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) {
if !gogoproto.IsStdType(m.ValueAliasField) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) && !gogoproto.IsCastTypeWith(field) {
g.RecordTypeUse(m.ValueAliasField.GetTypeName())
}
default:
Expand Down Expand Up @@ -2583,7 +2625,7 @@ func (g *Generator) generateOneofDecls(mc *msgCtx, topLevelFields []topLevelFiel
for i, sf := range of.subFields {
fieldFullPath := fmt.Sprintf("%s,%d,%d", mc.message.path, messageFieldPath, i)
g.P("type ", Annotate(mc.message.file, fieldFullPath, sf.oneofTypeName), " struct{ ", Annotate(mc.message.file, fieldFullPath, sf.goName), " ", sf.goType, " `", sf.tags, "` }")
if !gogoproto.IsStdType(sf.protoField) && !gogoproto.IsCustomType(sf.protoField) && !gogoproto.IsCastType(sf.protoField) {
if !gogoproto.IsStdType(sf.protoField) && !gogoproto.IsCustomType(sf.protoField) && !gogoproto.IsCastType(sf.protoField) && !gogoproto.IsCastTypeWith(sf.protoField) {
g.RecordTypeUse(sf.protoField.GetTypeName())
}
}
Expand Down Expand Up @@ -2950,7 +2992,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
}

oneofField.subFields = append(oneofField.subFields, &sf)
if !gogoproto.IsStdType(field) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) {
if !gogoproto.IsStdType(field) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) && !gogoproto.IsCastTypeWith(field) {
g.RecordTypeUse(field.GetTypeName())
}
continue
Expand Down Expand Up @@ -2983,15 +3025,15 @@ func (g *Generator) generateMessage(message *Descriptor) {
topLevelFields = append(topLevelFields, pf)

if gogoproto.HasTypeDecl(message.file.FileDescriptorProto, message.DescriptorProto) {
if !gogoproto.IsStdType(field) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) {
if !gogoproto.IsStdType(field) && !gogoproto.IsCustomType(field) && !gogoproto.IsCastType(field) && !gogoproto.IsCastTypeWith(field) {
g.RecordTypeUse(field.GetTypeName())
}
} else {
// Even if the type does not need to be generated, we need to iterate
// over all its fields to be able to mark as used any imported types
// used by those fields.
for _, mfield := range message.Field {
if !gogoproto.IsStdType(mfield) && !gogoproto.IsCustomType(mfield) && !gogoproto.IsCastType(mfield) {
if !gogoproto.IsStdType(mfield) && !gogoproto.IsCustomType(mfield) && !gogoproto.IsCastType(mfield) && !gogoproto.IsCastTypeWith(mfield) {
g.RecordTypeUse(mfield.GetTypeName())
}
}
Expand Down
24 changes: 24 additions & 0 deletions protoc-gen-gogo/generator/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ package generator

import (
"bytes"
"fmt"
"go/parser"
"go/printer"
"go/token"
Expand Down Expand Up @@ -404,6 +405,29 @@ func getCustomType(field *descriptor.FieldDescriptorProto) (packageName string,
return "", "", err
}

func GetCastTypeWith(field *descriptor.FieldDescriptorProto) (string, string, string, string, error) {
return getCastTypeWith(field)
}

func getCastTypeWith(field *descriptor.FieldDescriptorProto) (casteePkg string, casteeTyp string, casterPkg string, casterTyp string, err error) {
if field.Options != nil {
var v interface{}
v, err = proto.GetExtension(field.Options, gogoproto.E_Casttypewith)
if err == nil && v.(*string) != nil {
all := *(v.(*string))
ctypes := strings.Split(all, ";")
if len(ctypes) != 2 {
err = fmt.Errorf("Bad casttypewith syntax %s", all)
return
}
casteePkg, casteeTyp = splitCPackageType(ctypes[0])
casterPkg, casterTyp = splitCPackageType(ctypes[1])
return
}
}
return
}

func splitCPackageType(ctype string) (packageName string, typ string) {
ss := strings.Split(ctype, ".")
if len(ss) == 1 {
Expand Down
Loading