Skip to content

Commit

Permalink
Merge pull request #9 from g4s8/refactor
Browse files Browse the repository at this point in the history
refactor: extract ast parser logic
  • Loading branch information
g4s8 authored Feb 13, 2024
2 parents 94290d4 + 4da51ae commit 2c49c74
Show file tree
Hide file tree
Showing 15 changed files with 634 additions and 461 deletions.
240 changes: 240 additions & 0 deletions ast.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package main

import (
"fmt"
"go/ast"
"go/doc"
"go/token"
"log"
"strings"
)

// visitor nodes types
const (
nodeUnknown int = iota
nodeType
nodeRoot
nodeStruct
nodeField
)

type visitorNode struct {
kind int
typeName string // type name if node is a type or field type name if node is a field
names []string // it's possible that a field has multiple names
doc string // field or type documentation or comment if doc is empty
children []*visitorNode // optional children nodes for structs
typeRef *visitorNode // type reference if field is a struct
tag string // field tag
isArray bool // true if field is an array
}

type (
astCommentsHandler func(*ast.Comment) bool
astTypeDocResolver func(*ast.TypeSpec) string
)

type astVisitor struct {
commentHandler astCommentsHandler
typeDocResolver astTypeDocResolver
logger *log.Logger

currentNode *visitorNode
pendingType bool // true if the next type is a target type
targetName string // name of the type we are looking for
depth int // current depth in the AST (used for debugging, 1 based)
}

func newAstVisitor(commentsHandler astCommentsHandler, typeDocsResolver astTypeDocResolver) *astVisitor {
return &astVisitor{
commentHandler: commentsHandler,
typeDocResolver: typeDocsResolver,
logger: logger(),
depth: 1,
}
}

func (v *astVisitor) push(node *visitorNode, appendChild bool) *astVisitor {
if appendChild {
v.currentNode.children = append(v.currentNode.children, node)
}
return &astVisitor{
commentHandler: v.commentHandler,
typeDocResolver: v.typeDocResolver,
logger: v.logger,
pendingType: v.pendingType,
currentNode: node,
depth: v.depth + 1,
}
}

func (v *astVisitor) Walk(n ast.Node) {
ast.Walk(v, n)
v.resolveFieldTypes()
}

func (v *astVisitor) Visit(n ast.Node) ast.Visitor {
if v.currentNode == nil {
v.currentNode = &visitorNode{kind: nodeRoot}
}

switch t := n.(type) {
case *ast.Comment:
v.logger.Printf("ast(%d): visit comment", v.depth)
if !v.pendingType {
v.pendingType = v.commentHandler(t)
}
return v
case *ast.TypeSpec:
v.logger.Printf("ast(%d): visit type: %q", v.depth, t.Name.Name)
doc := v.typeDocResolver(t)
name := t.Name.Name
if v.pendingType {
v.targetName = name
v.pendingType = false
v.logger.Printf("ast(%d): detect target type: %q", v.depth, name)
}
typeNode := &visitorNode{
names: []string{name},
typeName: name,
kind: nodeType,
doc: doc,
}
return v.push(typeNode, true)
case *ast.StructType:
v.logger.Printf("ast(%d): found struct", v.depth)
switch v.currentNode.kind {
case nodeType:
v.currentNode.kind = nodeStruct
return v
case nodeField:
structNode := &visitorNode{
kind: nodeStruct,
doc: v.currentNode.doc,
}
v.currentNode.typeRef = structNode
return v.push(structNode, false)
default:
panic(fmt.Sprintf("unexpected node kind: %d", v.currentNode.kind))
}
case *ast.Field:
names := fieldNamesToStr(t)
v.logger.Printf("ast(%d): visit field (%v)", v.depth, names)
doc := getFieldDoc(t)
var (
tag string
isArray bool
)
if t.Tag != nil {
tag = t.Tag.Value
}
if _, ok := t.Type.(*ast.ArrayType); ok {
isArray = true
}
fieldNode := &visitorNode{
kind: nodeField,
names: names,
doc: doc,
tag: tag,
isArray: isArray,
}
if expr, ok := t.Type.(*ast.Ident); ok {
fieldNode.typeName = expr.Name
}
return v.push(fieldNode, true)
}
return v
}

func (v *astVisitor) resolveFieldTypes() {
unresolved := getAllNodes(v.currentNode, func(n *visitorNode) bool {
return n.kind == nodeField && n.typeRef == nil
})
structs := getAllNodes(v.currentNode, func(n *visitorNode) bool {
return n.kind == nodeStruct
})
structsByName := make(map[string]*visitorNode, len(structs))
for _, s := range structs {
structsByName[s.typeName] = s
}
for _, f := range unresolved {
if s, ok := structsByName[f.typeName]; ok {
f.typeRef = s
v.logger.Printf("ast: resolve field type %q to struct %q", f.names, s.typeName)
}
}
}

func getAllNodes(root *visitorNode, filter func(*visitorNode) bool) []*visitorNode {
var result []*visitorNode
if filter(root) {
result = append(result, root)
}
for _, c := range root.children {
result = append(result, getAllNodes(c, filter)...)
}
return result
}

func getFieldDoc(f *ast.Field) string {
doc := f.Doc.Text()
if doc == "" {
doc = f.Comment.Text()
}
return strings.TrimSpace(doc)
}

func fieldNamesToStr(f *ast.Field) []string {
names := make([]string, len(f.Names))
for i, n := range f.Names {
names[i] = n.Name
}
return names
}

func newASTTypeDocResolver(fileSet *token.FileSet, astFile *ast.File) (func(t *ast.TypeSpec) string, error) {
docs, err := doc.NewFromFiles(fileSet, []*ast.File{astFile}, "./", doc.PreserveAST)
if err != nil {
return nil, fmt.Errorf("extract package docs: %w", err)
}
return func(t *ast.TypeSpec) string {
typeName := t.Name.String()
docStr := strings.TrimSpace(t.Doc.Text())
if docStr == "" {
for _, t := range docs.Types {
if t.Name == typeName {
docStr = strings.TrimSpace(t.Doc)
break
}
}
}
return docStr
}, nil
}

var astCommentDummyHandler = func(*ast.Comment) bool {
return false
}

func newASTCommentTargetLineHandler(goGenLine int, linePositions []int) func(*ast.Comment) bool {
l := logger()
return func(c *ast.Comment) bool {
// if type name is not specified we should process the next type
// declaration after the comment with go:generate
// which causes this command to be executed.
var line int
for l, pos := range linePositions {
if token.Pos(pos) > c.Pos() {
break
}
// $GOLINE env var is 1-based.
line = l + 1
}
if line != goGenLine {
return false
}

l.Printf("found go:generate comment at line %d", line)
return true
}
}
19 changes: 19 additions & 0 deletions ast_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package main

import (
"go/ast"
"go/token"
"testing"
)

func TestASTTypeDocResolver(t *testing.T) {
t.Run("Fail", func(t *testing.T) {
fset := token.NewFileSet()
astFile := ast.File{}
_, err := newASTTypeDocResolver(fset, &astFile)
if err == nil {
t.Errorf("Expected error, got nil")
}
t.Logf("Error: %v", err)
})
}
12 changes: 0 additions & 12 deletions debug.go

This file was deleted.

10 changes: 10 additions & 0 deletions generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ func TestOptions(t *testing.T) {
t.Fatal("expected fieldNames to be true")
}
})
t.Run("WithType", func(t *testing.T) {
const typeName = "Foo"
g, err := newGenerator("stub", 1, withType(typeName))
if err != nil {
t.Fatal("new generator error", err)
}
if g.targetType != typeName {
t.Fatalf("expected targetType to be %q, got %q", typeName, g.targetType)
}
})
t.Run("empty", func(t *testing.T) {
g, err := newGenerator("stub", 1)
if err != nil {
Expand Down
Loading

0 comments on commit 2c49c74

Please sign in to comment.