Skip to content

Commit

Permalink
bake: refactor to use iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidGamba committed Jun 9, 2024
1 parent cdbcea4 commit 4f467eb
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 149 deletions.
195 changes: 92 additions & 103 deletions bake/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func GetFuncDeclForPackage(dir string, m *map[string]FuncDecl) error {
// opt.String("hola", "mundo")
// return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
func PrintAst(dir string) error {

cfg := &packages.Config{Mode: packages.NeedFiles | packages.NeedSyntax, Dir: dir}
pkgs, err := packages.Load(cfg, ".")
if err != nil {
Expand Down Expand Up @@ -165,121 +166,109 @@ func PrintAst(dir string) error {
// opt.String("hola", "mundo")
// return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
func ListAst(dir string) error {
cfg := &packages.Config{Mode: packages.NeedFiles | packages.NeedSyntax, Dir: dir}
pkgs, err := packages.Load(cfg, ".")
if err != nil {
return fmt.Errorf("failed to load packages: %w", err)
}
for _, pkg := range pkgs {
// Logger.Println(pkg.ID, pkg.GoFiles)
for _, file := range pkg.GoFiles {
// Logger.Printf("file: %s\n", file)
// parse file
fset := token.NewFileSet()
fset.AddFile(file, fset.Base(), len(file))
f, err := parser.ParseFile(fset, file, nil, parser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse file: %w", err)
}
// Iterate through every node in the file
ast.Inspect(f, func(n ast.Node) bool {
switch x := n.(type) {
// Check function declarations for exported functions
case *ast.FuncDecl:
if x.Name.IsExported() {
name := x.Name.Name
description := x.Doc.Text()
var buf bytes.Buffer
printer.Fprint(&buf, fset, x.Type)
Logger.Printf("file: %s\n", file)
Logger.Printf("type: %s, name: %s, desc: %s\n", buf.String(), name, strings.TrimSpace(description))
for p, err := range parsedFiles(dir) {
if err != nil {
return err
}

// Iterate through every node in the file
ast.Inspect(p.f, func(n ast.Node) bool {
switch x := n.(type) {
// Check function declarations for exported functions
case *ast.FuncDecl:
if x.Name.IsExported() {
name := x.Name.Name
description := x.Doc.Text()
var buf bytes.Buffer
printer.Fprint(&buf, p.fset, x.Type)
Logger.Printf("file: %s\n", p.file)
Logger.Printf("type: %s, name: %s, desc: %s\n", buf.String(), name, strings.TrimSpace(description))

// Check Params
// Expect opt *getoptions.GetOpt
if len(x.Type.Params.List) != 1 {
// Check Params
// Expect opt *getoptions.GetOpt
if len(x.Type.Params.List) != 1 {
return false
}
var optFieldName string
for _, param := range x.Type.Params.List {
name := param.Names[0].Name
var buf bytes.Buffer
printer.Fprint(&buf, p.fset, param.Type)
Logger.Printf("name: %s, %s\n", name, buf.String())
if buf.String() != "*getoptions.GetOpt" {
return false
}
var optFieldName string
for _, param := range x.Type.Params.List {
name := param.Names[0].Name
var buf bytes.Buffer
printer.Fprint(&buf, fset, param.Type)
Logger.Printf("name: %s, %s\n", name, buf.String())
if buf.String() != "*getoptions.GetOpt" {
return false
}
optFieldName = name
}
optFieldName = name
}

// Check Results
// Expect getoptions.CommandFn
if len(x.Type.Results.List) != 1 {
// Check Results
// Expect getoptions.CommandFn
if len(x.Type.Results.List) != 1 {
return false
}
for _, result := range x.Type.Results.List {
var buf bytes.Buffer
printer.Fprint(&buf, p.fset, result.Type)
Logger.Printf("result: %s\n", buf.String())
if buf.String() != "getoptions.CommandFn" {
return false
}
for _, result := range x.Type.Results.List {
var buf bytes.Buffer
printer.Fprint(&buf, fset, result.Type)
Logger.Printf("result: %s\n", buf.String())
if buf.String() != "getoptions.CommandFn" {
return false
}
}

// Check for Expressions of opt type
ast.Inspect(n, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.BlockStmt:
for i, stmt := range x.List {
var buf bytes.Buffer
printer.Fprint(&buf, fset, stmt)
// We are expecting the expression before the return function
_, ok := stmt.(*ast.ReturnStmt)
if ok {
return false
}
Logger.Printf("i: %d\n", i)
Logger.Printf("stmt: %s\n", buf.String())
exprStmt, ok := stmt.(*ast.ExprStmt)
if !ok {
continue
}
// spew.Dump(exprStmt)

// Check for CallExpr
ast.Inspect(exprStmt, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.CallExpr:
fun, ok := x.Fun.(*ast.SelectorExpr)
if !ok {
return false
}
xIdent, ok := fun.X.(*ast.Ident)
if !ok {
return false
}
if xIdent.Name != optFieldName {
return false
}
Logger.Printf("handling %s.%s\n", xIdent.Name, fun.Sel.Name)
}

switch fun.Sel.Name {
case "String":
handleString(optFieldName, n)
}
// Check for Expressions of opt type
ast.Inspect(n, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.BlockStmt:
for i, stmt := range x.List {
var buf bytes.Buffer
printer.Fprint(&buf, p.fset, stmt)
// We are expecting the expression before the return function
_, ok := stmt.(*ast.ReturnStmt)
if ok {
return false
}
Logger.Printf("i: %d\n", i)
Logger.Printf("stmt: %s\n", buf.String())
exprStmt, ok := stmt.(*ast.ExprStmt)
if !ok {
continue
}
// spew.Dump(exprStmt)

// Check for CallExpr
ast.Inspect(exprStmt, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.CallExpr:
fun, ok := x.Fun.(*ast.SelectorExpr)
if !ok {
return false
}
return true
})
}
xIdent, ok := fun.X.(*ast.Ident)
if !ok {
return false
}
if xIdent.Name != optFieldName {
return false
}
Logger.Printf("handling %s.%s\n", xIdent.Name, fun.Sel.Name)

switch fun.Sel.Name {
case "String":
handleString(optFieldName, n)
}

return false
}
return true
})
}
return true
})
}
}
return true
})
}
return true
})
}
}
return true
})
}
return nil
}
Expand Down
53 changes: 18 additions & 35 deletions bake/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"io"
"log"
"os"
"path/filepath"
"reflect"
"strings"
"unicode"

Expand All @@ -31,31 +29,34 @@ func program(args []string) int {
opt.SetUnknownMode(getoptions.Pass)
opt.Bool("quiet", false, opt.GetEnv("QUIET"))

bakefile, plug, err := loadPlugin(ctx)
dir, err := findBakeDir(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err)
return 1
}

err = loadOptFns(ctx, plug, opt, filepath.Dir(bakefile))
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: %s\n", err)
return 1
}
// bakefile, plug, err := loadPlugin(ctx)
// if err != nil {
// fmt.Fprintf(os.Stderr, "ERROR: %s\n", err)
// return 1
// }

b := opt.NewCommand("_bake", "")
// err = loadOptFns(ctx, plug, opt, filepath.Dir(bakefile))
// if err != nil {
// fmt.Fprintf(os.Stderr, "ERROR: %s\n", err)
// return 1
// }

bls := b.NewCommand("list-symbols", "lists symbols")
bls.SetCommandFn(ListSymbolsRun(bakefile))
b := opt.NewCommand("_bake", "")

bld := b.NewCommand("list-descriptions", "lists descriptions")
bld.SetCommandFn(ListDescriptionsRun(bakefile))
bld.SetCommandFn(ListDescriptionsRun(dir))

bast := b.NewCommand("show-ast", "show raw-ish ast")
bast.SetCommandFn(ShowASTRun(bakefile))
bast.SetCommandFn(ShowASTRun(dir))

bastList := b.NewCommand("list-ast", "list parsed ast")
bastList.SetCommandFn(ListASTRun(bakefile))
bastList.SetCommandFn(ListASTRun(dir))

opt.HelpCommand("help", opt.Alias("?"))
remaining, err := opt.Parse(args[1:])
Expand All @@ -64,18 +65,6 @@ func program(args []string) int {
return 1
}
if opt.Called("quiet") {
logger, err := plug.Lookup("Logger")
if err == nil {
var l **log.Logger
l, ok := logger.(*(*log.Logger))
if ok {
(*l).SetOutput(io.Discard)
} else {
Logger.Printf("failed to convert Logger: %s\n", reflect.TypeOf(logger))
}
} else {
Logger.Printf("failed to find Logger\n")
}
Logger.SetOutput(io.Discard)
}

Expand Down Expand Up @@ -156,10 +145,8 @@ func camelToKebab(camel string) string {
return buffer.String()
}

func ListDescriptionsRun(bakefile string) getoptions.CommandFn {
func ListDescriptionsRun(dir string) getoptions.CommandFn {
return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
Logger.Printf("bakefile: %s\n", bakefile)
dir := filepath.Dir(bakefile)
m := make(map[string]FuncDecl)
err := GetFuncDeclForPackage(dir, &m)
if err != nil {
Expand All @@ -173,10 +160,8 @@ func ListDescriptionsRun(bakefile string) getoptions.CommandFn {
}
}

func ShowASTRun(bakefile string) getoptions.CommandFn {
func ShowASTRun(dir string) getoptions.CommandFn {
return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
Logger.Printf("bakefile: %s\n", bakefile)
dir := filepath.Dir(bakefile)
err := PrintAst(dir)
if err != nil {
return fmt.Errorf("failed to inspect package: %w", err)
Expand All @@ -185,10 +170,8 @@ func ShowASTRun(bakefile string) getoptions.CommandFn {
}
}

func ListASTRun(bakefile string) getoptions.CommandFn {
func ListASTRun(dir string) getoptions.CommandFn {
return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
Logger.Printf("bakefile: %s\n", bakefile)
dir := filepath.Dir(bakefile)
err := ListAst(dir)
if err != nil {
return fmt.Errorf("failed to inspect package: %w", err)
Expand Down
11 changes: 0 additions & 11 deletions bake/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,6 @@ func loadOptFns(ctx context.Context, plug *plugin.Plugin, opt *getoptions.GetOpt
return nil
}

func ListSymbolsRun(bakefile string) getoptions.CommandFn {
return func(ctx context.Context, opt *getoptions.GetOpt, args []string) error {
plug, err := plugin.Open(bakefile)
if err != nil {
return fmt.Errorf("failed to open plugin: %w", err)
}
inspectPlugin(plug)
return nil
}
}

// https://github.com/golang/go/issues/17823
type Plug struct {
pluginpath string
Expand Down
47 changes: 47 additions & 0 deletions bake/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"iter"

"golang.org/x/tools/go/packages"
)

type parsedFile struct {
file string
fset *token.FileSet
f *ast.File
}

// Requires GOEXPERIMENT=rangefunc
func parsedFiles(dir string) iter.Seq2[parsedFile, error] {
return func(yield func(parsedFile, error) bool) {
cfg := &packages.Config{Mode: packages.NeedFiles | packages.NeedSyntax, Dir: dir}
pkgs, err := packages.Load(cfg, ".")
if err != nil {
yield(parsedFile{}, fmt.Errorf("failed to load packages: %w", err))
return
}
for _, pkg := range pkgs {
// Logger.Println(pkg.ID, pkg.GoFiles)
for _, file := range pkg.GoFiles {
p := parsedFile{}
// Logger.Printf("file: %s\n", file)
// parse file
fset := token.NewFileSet()
fset.AddFile(file, fset.Base(), len(file))
p.file = file
p.fset = fset
f, err := parser.ParseFile(fset, file, nil, parser.ParseComments)
if err != nil {
yield(p, fmt.Errorf("failed to parse file: %w", err))
}
p.f = f
yield(p, nil)
}
}
}
}

0 comments on commit 4f467eb

Please sign in to comment.