Skip to content

Commit

Permalink
add embedded message field support
Browse files Browse the repository at this point in the history
  • Loading branch information
Adphi committed Sep 23, 2021
1 parent f4fcce4 commit bacde2f
Show file tree
Hide file tree
Showing 11 changed files with 454 additions and 148 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
*.out
*.so
*.test
debug
.idea
5 changes: 5 additions & 0 deletions patch/go.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ message Options {
// For an enum value, this renames the generated Go const.
optional string name = 1;

// The embed option indicates the field should be embedded in the generated Go struct.
// Only message types can be embedded. Oneof fields cannot be embedded.
// See https://golang.org/ref/spec#Struct_types.
optional bool embed = 2;

// The getter option renames the generated getter method (default: Get<Field>)
// so a custom getter can be implemented in its place.
optional string getter = 10; // TODO: implement this
Expand Down
109 changes: 61 additions & 48 deletions patch/gopb/go.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 47 additions & 14 deletions patch/patcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Patcher struct {
objectRenames map[types.Object]string
tags map[protogen.GoIdent]string
fieldTags map[types.Object]string
embeds map[protogen.GoIdent]string
fieldEmbeds map[types.Object]string
}

// NewPatcher returns an initialized Patcher for gen.
Expand All @@ -63,6 +65,8 @@ func NewPatcher(gen *protogen.Plugin) (*Patcher, error) {
objectRenames: make(map[types.Object]string),
tags: make(map[protogen.GoIdent]string),
fieldTags: make(map[types.Object]string),
embeds: make(map[protogen.GoIdent]string),
fieldEmbeds: make(map[types.Object]string),
}
return p, p.scan()
}
Expand Down Expand Up @@ -243,7 +247,7 @@ func (p *Patcher) scanOneof(o *protogen.Oneof) {
newName = lint.Name(newName, lints.InitialismsMap())
}
if newName != "" {
p.RenameField(ident.WithChild(m.GoIdent, o.GoName), newName) // Oneof
p.RenameField(ident.WithChild(m.GoIdent, o.GoName), newName, false) // Oneof
p.RenameMethod(ident.WithChild(m.GoIdent, "Get"+o.GoName), "Get"+newName) // Getter
ifName := ident.WithPrefix(o.GoIdent, "is")
newIfName := "is" + p.nameFor(m.GoIdent) + "_" + newName
Expand All @@ -270,6 +274,24 @@ func (p *Patcher) scanField(f *protogen.Field) {
// Implicitly rename this oneof field because its parent(s) were renamed.
newName = f.GoName
}
// Embed field ?
embed := false
if opts.GetEmbed() {
switch {
case f.Message == nil:
log.Printf("Warning: embed declared for non-message field: %s", f.Desc.Name())
case f.Oneof != nil:
log.Printf("Warning: embed declared for oneof field: %s", f.Desc.Name())
default:
embed = true
// use the embed field message type's go name or rename option if defined
if mOpts := messageOptions(f.Message); mOpts.GetName() != "" {
newName = mOpts.GetName()
} else {
newName = f.Message.GoIdent.GoName
}
}
}
if lints.GetFields() || lints.GetAll() {
if newName == "" {
newName = f.GoName
Expand All @@ -278,12 +300,12 @@ func (p *Patcher) scanField(f *protogen.Field) {
}
if newName != "" {
if o != nil {
p.RenameType(f.GoIdent, p.nameFor(m.GoIdent)+"_"+newName) // Oneof wrapper struct
p.RenameField(ident.WithChild(f.GoIdent, f.GoName), newName) // Oneof wrapper field
p.RenameType(f.GoIdent, p.nameFor(m.GoIdent)+"_"+newName) // Oneof wrapper struct
p.RenameField(ident.WithChild(f.GoIdent, f.GoName), newName, false) // Oneof wrapper field (not embeddable)
ifName := ident.WithPrefix(o.GoIdent, "is")
p.RenameMethod(ident.WithChild(f.GoIdent, ifName.GoName), p.nameFor(ifName)) // Oneof interface method
} else {
p.RenameField(ident.WithChild(m.GoIdent, f.GoName), newName) // Field
p.RenameField(ident.WithChild(m.GoIdent, f.GoName), newName, embed) // Field
}
p.RenameMethod(ident.WithChild(m.GoIdent, "Get"+f.GoName), "Get"+newName) // Getter
}
Expand Down Expand Up @@ -344,9 +366,12 @@ func (p *Patcher) RenameValue(id protogen.GoIdent, newName string) {
// The id argument specifies a GoName from GoImportPath, e.g.: "github.com/org/repo/example".FooMessage.BarField
// newName should be the unqualified name (after the dot).
// The value of id.GoName should be the original generated identifier name, not a renamed identifier.
func (p *Patcher) RenameField(id protogen.GoIdent, newName string) {
func (p *Patcher) RenameField(id protogen.GoIdent, newName string, embed bool) {
p.renames[id] = newName
p.fieldRenames[id] = newName
if embed {
p.embeds[id] = newName
}
log.Printf("Rename field:\t%s.%s → %s", id.GoImportPath, id.GoName, newName)
}

Expand Down Expand Up @@ -510,6 +535,9 @@ func (p *Patcher) checkGoFiles() error {
continue
}
p.objectRenames[obj] = name
if _, ok := p.embeds[id]; ok {
p.fieldEmbeds[obj] = name
}
}

// Map struct tags.
Expand Down Expand Up @@ -648,7 +676,7 @@ func (p *Patcher) serializeGoFiles(res *pluginpb.CodeGeneratorResponse) error {
func (p *Patcher) patchGoFiles() error {
log.Printf("\nDefs")
for id, obj := range p.info.Defs {
p.patchIdent(id, obj)
p.patchIdent(id, obj, true)
p.patchTags(id, obj)
// if id.IsExported() {
// f := p.fset.File(id.NamePos)
Expand All @@ -658,27 +686,32 @@ func (p *Patcher) patchGoFiles() error {

log.Printf("\nUses\n")
for id, obj := range p.info.Uses {
p.patchIdent(id, obj)
p.patchIdent(id, obj, false)
}

log.Printf("\nUnresolved\n")
for _, f := range p.filesByName {
for _, id := range f.Unresolved {
p.patchIdent(id, nil)
p.patchIdent(id, nil, false)
}
}

return nil
}

func (p *Patcher) patchIdent(id *ast.Ident, obj types.Object) {
func (p *Patcher) patchIdent(id *ast.Ident, obj types.Object, isDecl bool) {
name := p.objectRenames[obj]
if name != "" {
p.patchComments(id, name)
id.Name = name
log.Printf("Renamed %s:\t%s → %s", typeString(obj), id.Name, name)
} else {
if name == "" {
// log.Printf("Unresolved:\t%v", id)
return
}
p.patchComments(id, name)
if _, ok := p.fieldEmbeds[obj]; ok && isDecl {
log.Printf("Renamed %s:\t%s → %s (embedded)", typeString(obj), id.Name, name)
id.Name = ""
} else {
log.Printf("Renamed %s:\t%s → %s", typeString(obj), id.Name, name)
id.Name = name
}
}

Expand Down
12 changes: 12 additions & 0 deletions tests/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ func ValidateEnum(t *testing.T, e protoreflect.Enum, names EnumNames, values Enu
// ValidateMessage performs basic validation of a message.
func ValidateMessage(t *testing.T, m proto.Message) {
// TODO: add some validation
b, err := proto.Marshal(m)
if err != nil {
t.Errorf("failed to marshal message: %v", err)
return
}
n := m.ProtoReflect().New().Interface()
if err := proto.Unmarshal(b, n); err != nil {
t.Errorf("failed to unmarshal message: %v", err)
}
if !proto.Equal(m, n) {
t.Errorf("marshal / unmarshal: expected %+v got %+v", m, n)
}
}

// ValidateTag performs basic validation of a struct tag.
Expand Down
Loading

0 comments on commit bacde2f

Please sign in to comment.