Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ parser:
make -C go/vt/sqlparser

visitor:
go run ./go/tools/asthelpergen -in ./go/vt/sqlparser -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName"
go run ./go/tools/asthelpergen/main -in ./go/vt/sqlparser -iface vitess.io/vitess/go/vt/sqlparser.SQLNode -except "*ColName"

sizegen:
go run go/tools/sizegen/sizegen.go \
Expand Down
43 changes: 6 additions & 37 deletions go/tools/asthelpergen/asthelpergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package main
package asthelpergen

import (
"bytes"
"flag"
"fmt"
"go/types"
"io/ioutil"
"log"
"path"
"strings"

Expand Down Expand Up @@ -170,48 +168,19 @@ func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) {
return result, nil
}

type typePaths []string
// TypePaths are the packages
type TypePaths []string

func (t *typePaths) String() string {
func (t *TypePaths) String() string {
return fmt.Sprintf("%v", *t)
}

func (t *typePaths) Set(path string) error {
// Set adds the package path
func (t *TypePaths) Set(path string) error {
*t = append(*t, path)
return nil
}

func main() {
var patterns typePaths
var generate, except string
var verify bool

flag.Var(&patterns, "in", "Go packages to load the generator")
flag.StringVar(&generate, "iface", "", "Root interface generate rewriter for")
flag.BoolVar(&verify, "verify", false, "ensure that the generated files are correct")
flag.StringVar(&except, "except", "", "don't deep clone these types")
flag.Parse()

result, err := GenerateASTHelpers(patterns, generate, except)
if err != nil {
log.Fatal(err)
}

if verify {
for _, err := range VerifyFilesOnDisk(result) {
log.Fatal(err)
}
log.Printf("%d files OK", len(result))
} else {
for fullPath, file := range result {
if err := file.Save(fullPath); err != nil {
log.Fatalf("failed to save file to '%s': %v", fullPath, err)
}
log.Printf("saved '%s'", fullPath)
}
}
}

// VerifyFilesOnDisk compares the generated results from the codegen against the files that
// currently exist on disk and returns any mismatches
func VerifyFilesOnDisk(result map[string]*jen.File) (errors []error) {
Expand Down
2 changes: 1 addition & 1 deletion go/tools/asthelpergen/asthelpergen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package main
package asthelpergen

import (
"fmt"
Expand Down
99 changes: 62 additions & 37 deletions go/tools/asthelpergen/clone_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package main
package asthelpergen

import (
"fmt"
"go/types"

"vitess.io/vitess/go/vt/log"
"log"
"strings"

"github.com/dave/jennifer/jen"
)
Expand Down Expand Up @@ -61,8 +61,22 @@ func (c *cloneGen) visitInterface(t types.Type, _ *types.Interface) error {

const cloneName = "Clone"

func (c *cloneGen) addFunc(name string, code jen.Code) {
c.methods = append(c.methods, jen.Comment(name+" creates a deep clone of the input."), code)
type methodType int

const (
clone methodType = iota
equals
)

func (c *cloneGen) addFunc(name string, typ methodType, code jen.Code) {
var comment string
switch typ {
case clone:
comment = " creates a deep clone of the input."
case equals:
comment = " does deep equals between the two objects."
}
c.methods = append(c.methods, jen.Comment(name+comment), code)
}

// readValueOfType produces code to read the expression of type `t`, and adds the type to the todo-list
Expand All @@ -81,10 +95,10 @@ func (c *cloneGen) readValueOfType(t types.Type, expr jen.Code) jen.Code {
}

func (c *cloneGen) makeStructCloneMethod(t types.Type) error {
receiveType := types.TypeString(t, noQualifier)
funcName := "Clone" + printableTypeName(t)
c.addFunc(funcName,
jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block(
typeString := types.TypeString(t, noQualifier)
funcName := cloneName + printableTypeName(t)
c.addFunc(funcName, clone,
jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block(
jen.Return(jen.Op("*").Add(c.readValueOfType(types.NewPointer(t), jen.Op("&").Id("n")))),
))
return nil
Expand All @@ -95,7 +109,7 @@ func (c *cloneGen) makeSliceCloneMethod(t types.Type, slice *types.Slice) error
name := printableTypeName(t)
funcName := cloneName + name

c.addFunc(funcName,
c.addFunc(funcName, clone,
//func (n Bytes) Clone() Bytes {
jen.Func().Id(funcName).Call(jen.Id("n").Id(typeString)).Id(typeString).Block(
// res := make(Bytes, len(n))
Expand Down Expand Up @@ -161,7 +175,7 @@ func (c *cloneGen) makeInterfaceCloneMethod(t types.Type, iface *types.Interface
}

default:
log.Errorf("unexpected type encountered: %s", typeString)
log.Fatalf("unexpected type encountered: %s", typeString)
}

return nil
Expand All @@ -180,22 +194,20 @@ func (c *cloneGen) makeInterfaceCloneMethod(t types.Type, iface *types.Interface

funcName := cloneName + typeName
funcDecl := jen.Func().Id(funcName).Call(jen.Id("in").Id(typeString)).Id(typeString).Block(stmts...)
c.addFunc(funcName, funcDecl)
c.addFunc(funcName, clone, funcDecl)
return nil
}

func (c *cloneGen) makePtrCloneMethod(t types.Type, ptr *types.Pointer) error {
func (c *cloneGen) makePtrCloneMethod(t types.Type, ptr *types.Pointer) {
receiveType := types.TypeString(t, noQualifier)

funcName := "Clone" + printableTypeName(t)
c.addFunc(funcName,
c.addFunc(funcName, clone,
jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType).Block(
ifNilReturnNil("n"),
jen.Id("out").Op(":=").Add(c.readValueOfType(ptr.Elem(), jen.Op("*").Id("n"))),
jen.Return(jen.Op("&").Id("out")),
))

return nil
}

func (c *cloneGen) createFile(pkgName string) (string, *jen.File) {
Expand All @@ -221,7 +233,7 @@ func (c *cloneGen) createFile(pkgName string) (string, *jen.File) {
continue
}

log.Errorf("don't know how to handle %s %T", typeName, underlying)
log.Fatalf("don't know how to handle %s %T", typeName, underlying)
}

for _, method := range c.methods {
Expand All @@ -241,14 +253,16 @@ func isBasic(t types.Type) bool {
}

func (c *cloneGen) tryStruct(underlying, t types.Type) bool {
_, ok := underlying.(*types.Struct)
strct, ok := underlying.(*types.Struct)
if !ok {
return false
}

err := c.makeStructCloneMethod(t)
if err != nil {
panic(err) // todo
if err := c.makeStructCloneMethod(t); err != nil {
log.Fatalf("%v", err)
}
if err := c.makeStructEqualsMethod(t, strct); err != nil {
log.Fatalf("%v", err)
}
return true
}
Expand All @@ -258,27 +272,33 @@ func (c *cloneGen) tryPtr(underlying, t types.Type) bool {
return false
}

if strct, isStruct := ptr.Elem().Underlying().(*types.Struct); isStruct {
c.makePtrToStructCloneMethod(t, strct)
ptrToType := ptr.Elem().Underlying()

switch ptrToType := ptrToType.(type) {
case *types.Struct:
c.makePtrToStructCloneMethod(t, ptrToType)
c.makePtrToStructEqualsMethod(t, ptrToType)
return true
case *types.Basic:
c.makePtrToBasicEqualsMethod(t)
c.makePtrCloneMethod(t, ptr)
return true
default:
c.makePtrCloneMethod(t, ptr)
}

err := c.makePtrCloneMethod(t, ptr)
if err != nil {
panic(err) // todo
}
return true
}

func (c *cloneGen) makePtrToStructCloneMethod(t types.Type, strct *types.Struct) {
receiveType := types.TypeString(t, noQualifier)
funcName := "Clone" + printableTypeName(t)
funcName := cloneName + printableTypeName(t)

//func CloneRefOfType(n *Type) *Type
funcDeclaration := jen.Func().Id(funcName).Call(jen.Id("n").Id(receiveType)).Id(receiveType)

if receiveType == c.exceptType {
c.addFunc(funcName, funcDeclaration.Block(
c.addFunc(funcName, clone, funcDeclaration.Block(
jen.Return(jen.Id("n")),
))
return
Expand Down Expand Up @@ -310,7 +330,7 @@ func (c *cloneGen) makePtrToStructCloneMethod(t types.Type, strct *types.Struct)
jen.Return(jen.Op("&").Id("out")),
)

c.addFunc(funcName,
c.addFunc(funcName, clone,
funcDeclaration.Block(stmts...),
)
}
Expand All @@ -321,9 +341,12 @@ func (c *cloneGen) tryInterface(underlying, t types.Type) bool {
return false
}

err := c.makeInterfaceCloneMethod(t, iface)
if err != nil {
panic(err) // todo
if err := c.makeInterfaceCloneMethod(t, iface); err != nil {
log.Fatalf("%v", err)
}

if err := c.makeInterfaceEqualsMethod(t, iface); err != nil {
log.Fatalf("%v", err)
}
return true
}
Expand All @@ -334,9 +357,11 @@ func (c *cloneGen) trySlice(underlying, t types.Type) bool {
return false
}

err := c.makeSliceCloneMethod(t, slice)
if err != nil {
panic(err) // todo
if err := c.makeSliceCloneMethod(t, slice); err != nil {
log.Fatalf("%v", err)
}
if err := c.makeSliceEqualsMethod(t, slice); err != nil {
log.Fatalf("%v", err)
}
return true
}
Expand All @@ -351,7 +376,7 @@ func printableTypeName(t types.Type) string {
case *types.Named:
return t.Obj().Name()
case *types.Basic:
return t.Name()
return strings.Title(t.Name())
case *types.Interface:
return t.String()
default:
Expand Down
Loading