Skip to content

Commit

Permalink
Codegen for well-known types
Browse files Browse the repository at this point in the history
- Fixes gogo#472
  • Loading branch information
virtuald committed Sep 21, 2018
1 parent e14cafb commit 849e0c1
Show file tree
Hide file tree
Showing 19 changed files with 3,392 additions and 80 deletions.
2 changes: 2 additions & 0 deletions gogoproto/gogo.proto
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,6 @@ extend google.protobuf.FieldOptions {

optional bool stdtime = 65010;
optional bool stdduration = 65011;
optional bool wktpointer = 65012;

}
49 changes: 49 additions & 0 deletions gogoproto/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,55 @@ func IsStdDuration(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Stdduration, false)
}

func IsStdDouble(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.DoubleValue"
}

func IsStdFloat(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.FloatValue"
}

func IsStdInt64(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.Int64Value"
}

func IsStdUInt64(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.UInt64Value"
}

func IsStdInt32(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.Int32Value"
}

func IsStdUInt32(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.UInt32Value"
}

func IsStdBool(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.BoolValue"
}

func IsStdString(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.StringValue"
}

func IsStdBytes(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false) && *field.TypeName == ".google.protobuf.BytesValue"
}

func IsStdType(field *google_protobuf.FieldDescriptorProto) bool {
return (IsStdTime(field) || IsStdDuration(field) ||
IsStdDouble(field) || IsStdFloat(field) ||
IsStdInt64(field) || IsStdUInt64(field) ||
IsStdInt32(field) || IsStdUInt32(field) ||
IsStdBool(field) ||
IsStdString(field) || IsStdBytes(field))
}

func IsWktPtr(field *google_protobuf.FieldDescriptorProto) bool {
return proto.GetBoolExtension(field.Options, E_Wktpointer, false)
}

func NeedsNilCheck(proto3 bool, field *google_protobuf.FieldDescriptorProto) bool {
nullable := IsNullable(field)
if field.IsMessage() || IsCustomType(field) {
Expand Down
65 changes: 60 additions & 5 deletions plugin/equal/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,16 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
repeated := field.IsRepeated()
ctype := gogoproto.IsCustomType(field)
nullable := gogoproto.IsNullable(field)
isDuration := gogoproto.IsStdDuration(field)
isNormal := (gogoproto.IsStdDuration(field) ||
gogoproto.IsStdDouble(field) ||
gogoproto.IsStdFloat(field) ||
gogoproto.IsStdInt64(field) ||
gogoproto.IsStdUInt64(field) ||
gogoproto.IsStdInt32(field) ||
gogoproto.IsStdUInt32(field) ||
gogoproto.IsStdBool(field) ||
gogoproto.IsStdString(field))
isBytes := gogoproto.IsStdBytes(field)
isTimestamp := gogoproto.IsStdTime(field)
// oneof := field.OneofIndex != nil
if !repeated {
Expand Down Expand Up @@ -322,7 +331,7 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
}
p.Out()
p.P(`}`)
} else if isDuration {
} else if isNormal {
if nullable {
p.generateNullableField(fieldname, verbose)
} else {
Expand All @@ -336,6 +345,32 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
}
p.Out()
p.P(`}`)
} else if isBytes {
if nullable {
p.P(`if that1.`, fieldname, ` == nil {`)
p.In()
p.P(`if this.`, fieldname, ` != nil {`)
p.In()
if verbose {
p.P(`return `, p.fmtPkg.Use(), `.Errorf("this.`, fieldname, ` != nil && that1.`, fieldname, ` == nil")`)
} else {
p.P(`return false`)
}
p.Out()
p.P(`}`)
p.Out()
p.P(`} else if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `, that1.`, fieldname, `) {`)
} else {
p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `, that1.`, fieldname, `) {`)
}
p.In()
if verbose {
p.P(`return `, p.fmtPkg.Use(), `.Errorf("`, fieldname, ` this(%v) Not Equal that(%v)", this.`, fieldname, `, that1.`, fieldname, `)`)
} else {
p.P(`return false`)
}
p.Out()
p.P(`}`)
} else {
if field.IsMessage() || p.IsGroup(field) {
if nullable {
Expand Down Expand Up @@ -387,12 +422,14 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
} else {
p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`)
}
} else if isDuration {
} else if isNormal {
if nullable {
p.P(`if dthis, dthat := this.`, fieldname, `[i], that1.`, fieldname, `[i]; (dthis != nil && dthat != nil && *dthis != *dthat) || (dthis != nil && dthat == nil) || (dthis == nil && dthat != nil) {`)
} else {
p.P(`if this.`, fieldname, `[i] != that1.`, fieldname, `[i] {`)
}
} else if isBytes {
p.P(`if !`, p.bytesPkg.Use(), `.Equal(this.`, fieldname, `[i], that1.`, fieldname, `[i]) {`)
} else {
if p.IsMap(field) {
m := p.GoMapType(nil, field)
Expand All @@ -401,21 +438,39 @@ func (p *plugin) generateField(file *generator.FileDescriptor, message *generato
nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp)

mapValue := m.ValueAliasField
mapValueNormal := (gogoproto.IsStdDuration(mapValue) ||
gogoproto.IsStdDouble(mapValue) ||
gogoproto.IsStdFloat(mapValue) ||
gogoproto.IsStdInt64(mapValue) ||
gogoproto.IsStdUInt64(mapValue) ||
gogoproto.IsStdInt32(mapValue) ||
gogoproto.IsStdUInt32(mapValue) ||
gogoproto.IsStdBool(mapValue) ||
gogoproto.IsStdString(mapValue))
mapValueBytes := gogoproto.IsStdBytes(mapValue)
if mapValue.IsMessage() || p.IsGroup(mapValue) {
if nullable && valuegoTyp == valuegoAliasTyp {
p.P(`if !this.`, fieldname, `[i].Equal(that1.`, fieldname, `[i]) {`)
} else {
// Equal() has a pointer receiver, but map value is a value type
a := `this.` + fieldname + `[i]`
b := `that1.` + fieldname + `[i]`
if valuegoTyp != valuegoAliasTyp {
if !mapValueNormal && !mapValueBytes && valuegoTyp != valuegoAliasTyp {
// cast back to the type that has the generated methods on it
a = `(` + valuegoTyp + `)(` + a + `)`
b = `(` + valuegoTyp + `)(` + b + `)`
}
p.P(`a := `, a)
p.P(`b := `, b)
if nullable {
if mapValueNormal {
if nullable {
p.P(`if *a != *b {`)
} else {
p.P(`if a != b {`)
}
} else if mapValueBytes {
p.P(`if !`, p.bytesPkg.Use(), `.Equal(a, b) {`)
} else if nullable {
p.P(`if !a.Equal(b) {`)
} else {
p.P(`if !(&a).Equal(&b) {`)
Expand Down
2 changes: 1 addition & 1 deletion plugin/gostring/gostring.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (p *gostring) Generate(file *generator.FileDescriptor) {
p.P(`s = append(s, "`, fieldname, `: " + `, mapName, `+ ",\n")`)
p.Out()
p.P(`}`)
} else if (field.IsMessage() && !gogoproto.IsCustomType(field) && !gogoproto.IsStdTime(field) && !gogoproto.IsStdDuration(field)) || p.IsGroup(field) {
} else if (field.IsMessage() && !gogoproto.IsCustomType(field) && !gogoproto.IsStdType(field)) || p.IsGroup(field) {
if nullable || repeated {
p.P(`if this.`, fieldname, ` != nil {`)
p.In()
Expand Down
150 changes: 148 additions & 2 deletions plugin/marshalto/marshalto.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,33 @@ func (p *marshalto) mapField(numGen NumGen, field *descriptor.FieldDescriptorPro
} else if gogoproto.IsStdDuration(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdDouble(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDouble(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDoubleMarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdFloat(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdFloat(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdFloatMarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt64(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt64(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdInt64MarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt64(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt64(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdUInt64MarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt32(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt32(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdInt32MarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt32(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt32(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdUInt32MarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBool(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBool(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdBoolMarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdString(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdString(*`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdStringMarshalTo(*`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBytes(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBytes(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdBytesMarshalTo(`, varName, `, dAtA[i:])`)
} else if protoSizer {
p.callVarint(varName, `.ProtoSize()`)
p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`)
Expand Down Expand Up @@ -781,8 +808,7 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
sum = append(sum, `soz`+p.localName+`(uint64(v))`)
case descriptor.FieldDescriptorProto_TYPE_MESSAGE:
if valuegoTyp != valuegoAliasTyp &&
!gogoproto.IsStdTime(field) &&
!gogoproto.IsStdDuration(field) {
!gogoproto.IsStdType(field) {
if nullable {
// cast back to the type that has the generated methods on it
accessor = `((` + valuegoTyp + `)(` + accessor + `))`
Expand All @@ -799,6 +825,24 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdTime(*`, accessor, `)`)
} else if gogoproto.IsStdDuration(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdDuration(*`, accessor, `)`)
} else if gogoproto.IsStdDouble(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdDouble(*`, accessor, `)`)
} else if gogoproto.IsStdFloat(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdFloat(*`, accessor, `)`)
} else if gogoproto.IsStdInt64(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdInt64(*`, accessor, `)`)
} else if gogoproto.IsStdUInt64(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdUInt64(*`, accessor, `)`)
} else if gogoproto.IsStdInt32(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdInt32(*`, accessor, `)`)
} else if gogoproto.IsStdUInt32(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdUInt32(*`, accessor, `)`)
} else if gogoproto.IsStdBool(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdBool(*`, accessor, `)`)
} else if gogoproto.IsStdString(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdString(*`, accessor, `)`)
} else if gogoproto.IsStdBytes(field) {
p.P(`msgSize = `, p.typesPkg.Use(), `.SizeOfStdBytes(`, accessor, `)`)
} else if protoSizer {
p.P(`msgSize = `, accessor, `.ProtoSize()`)
} else {
Expand Down Expand Up @@ -852,6 +896,57 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdDouble(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDouble(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdDoubleMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdFloat(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdFloat(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdFloatMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt64(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt64(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdInt64MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt64(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt64(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdUInt64MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt32(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt32(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdInt32MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt32(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt32(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdUInt32MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBool(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBool(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdBoolMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdString(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdString(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdStringMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBytes(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBytes(`, varName, `)`)
p.P(`n, err := `, p.typesPkg.Use(), `.StdBytesMarshalTo(`, varName, `, dAtA[i:])`)
} else if protoSizer {
p.callVarint(varName, ".ProtoSize()")
p.P(`n, err := `, varName, `.MarshalTo(dAtA[i:])`)
Expand Down Expand Up @@ -882,6 +977,57 @@ func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.Fi
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDuration(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDurationMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdDouble(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdDouble(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdDoubleMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdFloat(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdFloat(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdFloatMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt64(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt64(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdInt64MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt64(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt64(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdUInt64MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdInt32(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdInt32(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdInt32MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdUInt32(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdUInt32(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdUInt32MarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBool(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBool(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdBoolMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdString(field) {
if gogoproto.IsNullable(field) {
varName = "*" + varName
}
p.callVarint(p.typesPkg.Use(), `.SizeOfStdString(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdStringMarshalTo(`, varName, `, dAtA[i:])`)
} else if gogoproto.IsStdBytes(field) {
p.callVarint(p.typesPkg.Use(), `.SizeOfStdBytes(`, varName, `)`)
p.P(`n`, numGen.Next(), `, err := `, p.typesPkg.Use(), `.StdBytesMarshalTo(`, varName, `, dAtA[i:])`)
} else if protoSizer {
p.callVarint(varName, `.ProtoSize()`)
p.P(`n`, numGen.Next(), `, err := `, varName, `.MarshalTo(dAtA[i:])`)
Expand Down
Loading

0 comments on commit 849e0c1

Please sign in to comment.