Skip to content

Commit

Permalink
add field type support
Browse files Browse the repository at this point in the history
  • Loading branch information
Adphi committed Sep 27, 2021
1 parent b9c6d43 commit f607bf3
Show file tree
Hide file tree
Showing 18 changed files with 1,275 additions and 110 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/fatih/structtag v1.2.0
github.com/iancoleman/strcase v0.1.2 // indirect
github.com/lyft/protoc-gen-star v0.5.2 // indirect
github.com/stretchr/testify v1.7.0
golang.org/x/tools v0.1.6
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0
google.golang.org/protobuf v1.27.1
Expand Down
3 changes: 2 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ github.com/spf13/afero v1.3.4 h1:8q6vk3hthlpb2SouZcnBVKboxWQWMDNF38bwholZrJc=
github.com/spf13/afero v1.3.4/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down
202 changes: 202 additions & 0 deletions patch/field_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
package patch

import (
"go/ast"
"go/types"
"log"
"strings"

"golang.org/x/tools/go/ast/astutil"
)

func (p *Patcher) patchTypeDef(id *ast.Ident, obj types.Object) {
fieldType, ok := p.fieldTypes[obj]
if !ok {
return
}

parent := p.findParentNode(id)
if pkg, name, isSlice := packageAndName(fieldType); pkg != "" {
f := p.fileOf(id)
pkgImport := packageImport(pkg)
astutil.AddNamedImport(p.fset, f, pkgImport, pkg)
fieldType = pkgImport + "." + name
if isSlice {
fieldType = "[]" + fieldType
}
}
castDecl := func(v *ast.Field) bool {
switch t := v.Type.(type) {
case *ast.Ident:
t.Name = fieldType
return true
case *ast.ArrayType:
if isSliceType(fieldType) {
if id, ok := t.Elt.(*ast.Ident); ok {
id.Name = strings.TrimPrefix(fieldType, "[]")
return true
}
} else {
v.Type = &ast.Ident{
Name: fieldType,
}
return true
}
return false
default:
return false
}
}

// Cast Field definition
if id.Obj != nil && id.Obj.Decl != nil {
v, ok := id.Obj.Decl.(*ast.Field)
if !ok {
log.Printf("Warning: fieldType declared for non-field object: %v `%s`", obj, fieldType)
return
}
if !castDecl(v) {
log.Printf("Warning: unsupported fieldType type: %T `%s`", v.Type, fieldType)
}
return
}
switch obj.Type().(type) {
// Cast Getter signature
case *types.Signature:
n, ok := parent.(*ast.FuncDecl)
if !ok {
log.Printf("Warning: unexpected type for getter: %v `%T`", obj, parent)
break
}
if l := len(n.Type.Results.List); l != 1 {
log.Printf("Warning: unexpected return count for getter: %v `%d`", obj, l)
return
}
if !castDecl(n.Type.Results.List[0]) {
log.Printf("Warning: unsupported fieldType type: %T `%s`", n.Type.Results.List[0].Type, fieldType)
}
return
}
}

func (p *Patcher) patchTypeUsage(id *ast.Ident, obj types.Object) {
desiredType, ok := p.fieldTypes[obj]
if !ok {
return
}
var originalType string
switch t := obj.Type().(type) {
case *types.Basic:
originalType = t.Name()
case *types.Signature:
if t.Results().Len() != 1 {
return
}
originalType = t.Results().At(0).Type().String()
}
usageNode := p.findParentNode(id)
pkgPath, pkgName, isSlice := packageAndName(desiredType)
pkgImport := packageImport(pkgPath)
if pkgPath != "" {
desiredType = pkgImport + "." + pkgName
if isSlice {
desiredType = "[]" + desiredType
}
}
cast := func(as string, expr ast.Expr) ast.Expr {
if pkgPath != "" && as == desiredType {
f := p.fileOf(id)
// astutil.AddNamedImport already check for duplicated imports, so there is no need to do it here
astutil.AddNamedImport(p.fset, f, pkgImport, pkgPath)
}
return &ast.CallExpr{
Fun: &ast.Ident{
Name: as,
},
Args: []ast.Expr{expr},
}
}
parentNode := p.findParentNode(usageNode)

switch usage := usageNode.(type) {
case *ast.SelectorExpr:
switch parentExpr := parentNode.(type) {
case *ast.AssignStmt:
if len(parentExpr.Lhs) != len(parentExpr.Rhs) {
return
}
for i := range parentExpr.Lhs {
if parentExpr.Lhs[i] == usage {
parentExpr.Rhs[i] = cast(desiredType, parentExpr.Rhs[i])
return
}
}
for i := range parentExpr.Rhs {
if parentExpr.Rhs[i] == usage {
parentExpr.Rhs[i] = cast(originalType, parentExpr.Rhs[i])
return
}
}
case *ast.CallExpr:
parent := p.findParentNode(parentExpr)
assign, isAssign := parent.(*ast.AssignStmt)
if parentExpr.Fun == usage && isAssign {
for i := range assign.Rhs {
if assign.Rhs[i] == parentExpr {
assign.Rhs[i] = cast(originalType, assign.Rhs[i])
return
}
}
}
call, isCall := parent.(*ast.CallExpr)
if isCall {
for i := range call.Args {
if call.Args[i] == parentExpr {
call.Args[i] = cast(originalType, call.Args[i])
return
}
}
}
for i, v := range parentExpr.Args {
if v == usage {
parentExpr.Args[i] = cast(originalType, usage)
return
}
}
case *ast.BinaryExpr:
if parentExpr.X == usage {
parentExpr.X = cast(originalType, parentExpr.X)
}
if parentExpr.Y == usage {
parentExpr.Y = cast(originalType, parentExpr.Y)
}
}
case *ast.KeyValueExpr:
if usage.Key == id {
usage.Value = cast(desiredType, usage.Value)
return
}
if usage.Value == id {
usage.Value = cast(originalType, usage.Value)
return
}
}
}

func packageAndName(fqn string) (pkg string, name string, isSlice bool) {
isSlice = isSliceType(fqn)
fqn = strings.TrimPrefix(fqn, "[]")
parts := strings.Split(fqn, ".")
if len(parts) == 1 {
return "", fqn, isSlice
}
return strings.Join(parts[:len(parts)-1], "."), parts[len(parts)-1], isSlice
}

func isSliceType(typeName string) bool {
return strings.HasPrefix(typeName, "[]")
}

func packageImport(pkg string) string {
return strings.Replace(strings.Replace(pkg, "/", "_", -1), ".", "_", -1)
}
Loading

0 comments on commit f607bf3

Please sign in to comment.