diff --git a/ast.go b/ast.go new file mode 100644 index 0000000..95c09d6 --- /dev/null +++ b/ast.go @@ -0,0 +1,240 @@ +package main + +import ( + "fmt" + "go/ast" + "go/doc" + "go/token" + "log" + "strings" +) + +// visitor nodes types +const ( + nodeUnknown int = iota + nodeType + nodeRoot + nodeStruct + nodeField +) + +type visitorNode struct { + kind int + typeName string // type name if node is a type or field type name if node is a field + names []string // it's possible that a field has multiple names + doc string // field or type documentation or comment if doc is empty + children []*visitorNode // optional children nodes for structs + typeRef *visitorNode // type reference if field is a struct + tag string // field tag + isArray bool // true if field is an array +} + +type ( + astCommentsHandler func(*ast.Comment) bool + astTypeDocResolver func(*ast.TypeSpec) string +) + +type astVisitor struct { + commentHandler astCommentsHandler + typeDocResolver astTypeDocResolver + logger *log.Logger + + currentNode *visitorNode + pendingType bool // true if the next type is a target type + targetName string // name of the type we are looking for + depth int // current depth in the AST (used for debugging, 1 based) +} + +func newAstVisitor(commentsHandler astCommentsHandler, typeDocsResolver astTypeDocResolver) *astVisitor { + return &astVisitor{ + commentHandler: commentsHandler, + typeDocResolver: typeDocsResolver, + logger: logger(), + depth: 1, + } +} + +func (v *astVisitor) push(node *visitorNode, appendChild bool) *astVisitor { + if appendChild { + v.currentNode.children = append(v.currentNode.children, node) + } + return &astVisitor{ + commentHandler: v.commentHandler, + typeDocResolver: v.typeDocResolver, + logger: v.logger, + pendingType: v.pendingType, + currentNode: node, + depth: v.depth + 1, + } +} + +func (v *astVisitor) Walk(n ast.Node) { + ast.Walk(v, n) + v.resolveFieldTypes() +} + +func (v *astVisitor) Visit(n ast.Node) ast.Visitor { + if v.currentNode == nil { + v.currentNode = &visitorNode{kind: nodeRoot} + } + + switch t := n.(type) { + case *ast.Comment: + v.logger.Printf("ast(%d): visit comment", v.depth) + if !v.pendingType { + v.pendingType = v.commentHandler(t) + } + return v + case *ast.TypeSpec: + v.logger.Printf("ast(%d): visit type: %q", v.depth, t.Name.Name) + doc := v.typeDocResolver(t) + name := t.Name.Name + if v.pendingType { + v.targetName = name + v.pendingType = false + v.logger.Printf("ast(%d): detect target type: %q", v.depth, name) + } + typeNode := &visitorNode{ + names: []string{name}, + typeName: name, + kind: nodeType, + doc: doc, + } + return v.push(typeNode, true) + case *ast.StructType: + v.logger.Printf("ast(%d): found struct", v.depth) + switch v.currentNode.kind { + case nodeType: + v.currentNode.kind = nodeStruct + return v + case nodeField: + structNode := &visitorNode{ + kind: nodeStruct, + doc: v.currentNode.doc, + } + v.currentNode.typeRef = structNode + return v.push(structNode, false) + default: + panic(fmt.Sprintf("unexpected node kind: %d", v.currentNode.kind)) + } + case *ast.Field: + names := fieldNamesToStr(t) + v.logger.Printf("ast(%d): visit field (%v)", v.depth, names) + doc := getFieldDoc(t) + var ( + tag string + isArray bool + ) + if t.Tag != nil { + tag = t.Tag.Value + } + if _, ok := t.Type.(*ast.ArrayType); ok { + isArray = true + } + fieldNode := &visitorNode{ + kind: nodeField, + names: names, + doc: doc, + tag: tag, + isArray: isArray, + } + if expr, ok := t.Type.(*ast.Ident); ok { + fieldNode.typeName = expr.Name + } + return v.push(fieldNode, true) + } + return v +} + +func (v *astVisitor) resolveFieldTypes() { + unresolved := getAllNodes(v.currentNode, func(n *visitorNode) bool { + return n.kind == nodeField && n.typeRef == nil + }) + structs := getAllNodes(v.currentNode, func(n *visitorNode) bool { + return n.kind == nodeStruct + }) + structsByName := make(map[string]*visitorNode, len(structs)) + for _, s := range structs { + structsByName[s.typeName] = s + } + for _, f := range unresolved { + if s, ok := structsByName[f.typeName]; ok { + f.typeRef = s + v.logger.Printf("ast: resolve field type %q to struct %q", f.names, s.typeName) + } + } +} + +func getAllNodes(root *visitorNode, filter func(*visitorNode) bool) []*visitorNode { + var result []*visitorNode + if filter(root) { + result = append(result, root) + } + for _, c := range root.children { + result = append(result, getAllNodes(c, filter)...) + } + return result +} + +func getFieldDoc(f *ast.Field) string { + doc := f.Doc.Text() + if doc == "" { + doc = f.Comment.Text() + } + return strings.TrimSpace(doc) +} + +func fieldNamesToStr(f *ast.Field) []string { + names := make([]string, len(f.Names)) + for i, n := range f.Names { + names[i] = n.Name + } + return names +} + +func newASTTypeDocResolver(fileSet *token.FileSet, astFile *ast.File) (func(t *ast.TypeSpec) string, error) { + docs, err := doc.NewFromFiles(fileSet, []*ast.File{astFile}, "./", doc.PreserveAST) + if err != nil { + return nil, fmt.Errorf("extract package docs: %w", err) + } + return func(t *ast.TypeSpec) string { + typeName := t.Name.String() + docStr := strings.TrimSpace(t.Doc.Text()) + if docStr == "" { + for _, t := range docs.Types { + if t.Name == typeName { + docStr = strings.TrimSpace(t.Doc) + break + } + } + } + return docStr + }, nil +} + +var astCommentDummyHandler = func(*ast.Comment) bool { + return false +} + +func newASTCommentTargetLineHandler(goGenLine int, linePositions []int) func(*ast.Comment) bool { + l := logger() + return func(c *ast.Comment) bool { + // if type name is not specified we should process the next type + // declaration after the comment with go:generate + // which causes this command to be executed. + var line int + for l, pos := range linePositions { + if token.Pos(pos) > c.Pos() { + break + } + // $GOLINE env var is 1-based. + line = l + 1 + } + if line != goGenLine { + return false + } + + l.Printf("found go:generate comment at line %d", line) + return true + } +} diff --git a/ast_test.go b/ast_test.go new file mode 100644 index 0000000..ce441b6 --- /dev/null +++ b/ast_test.go @@ -0,0 +1,19 @@ +package main + +import ( + "go/ast" + "go/token" + "testing" +) + +func TestASTTypeDocResolver(t *testing.T) { + t.Run("Fail", func(t *testing.T) { + fset := token.NewFileSet() + astFile := ast.File{} + _, err := newASTTypeDocResolver(fset, &astFile) + if err == nil { + t.Errorf("Expected error, got nil") + } + t.Logf("Error: %v", err) + }) +} diff --git a/debug.go b/debug.go deleted file mode 100644 index dc02cba..0000000 --- a/debug.go +++ /dev/null @@ -1,12 +0,0 @@ -package main - -import "fmt" - -const debugLogs = false - -func debug(f string, args ...any) { - if !debugLogs { - return - } - fmt.Printf("DEBUG: "+f+"\n", args...) -} diff --git a/generator_test.go b/generator_test.go index fb9de93..2594c15 100644 --- a/generator_test.go +++ b/generator_test.go @@ -70,6 +70,16 @@ func TestOptions(t *testing.T) { t.Fatal("expected fieldNames to be true") } }) + t.Run("WithType", func(t *testing.T) { + const typeName = "Foo" + g, err := newGenerator("stub", 1, withType(typeName)) + if err != nil { + t.Fatal("new generator error", err) + } + if g.targetType != typeName { + t.Fatalf("expected targetType to be %q, got %q", typeName, g.targetType) + } + }) t.Run("empty", func(t *testing.T) { g, err := newGenerator("stub", 1) if err != nil { diff --git a/inspector.go b/inspector.go index d60631e..99bae29 100644 --- a/inspector.go +++ b/inspector.go @@ -2,292 +2,163 @@ package main import ( "fmt" - "go/ast" - "go/doc" "go/parser" "go/token" + "log" "strings" ) -type envFieldKind int - -const ( - envFieldKindPlain envFieldKind = iota - envFieldKindStruct // struct reference -) - -type envField struct { - name string - kind envFieldKind - doc string - opts EnvVarOptions - typeRef string - fieldName string - envPrefix string -} - -type envStruct struct { - name string - doc string - fields []envField -} - -type anonymousStruct struct { - name string // generated name - doc *ast.CommentGroup - comments *ast.CommentGroup -} - type inspector struct { typeName string // type name to generate documentation for, could be empty all bool // generate documentation for all types in the file execLine int // line number of the go:generate directive useFieldNames bool // use field names if tag is not specified - - fileSet *token.FileSet - lines []int - pendingType bool - items []*envStruct - anonymousStructs map[[2]token.Pos]anonymousStruct // map of anonymous structs by token position - doc *doc.Package - err error + log *log.Logger } func newInspector(typeName string, all bool, execLine int, useFieldNames bool) *inspector { return &inspector{ - typeName: typeName, - all: all, - execLine: execLine, - useFieldNames: useFieldNames, - anonymousStructs: make(map[[2]token.Pos]anonymousStruct), + typeName: typeName, + all: all, + execLine: execLine, + useFieldNames: useFieldNames, + log: logger(), } } func (i *inspector) inspectFile(fileName string) ([]*EnvScope, error) { - i.fileSet = token.NewFileSet() - file, err := parser.ParseFile(i.fileSet, fileName, nil, parser.ParseComments) + fileSet := token.NewFileSet() + file, err := parser.ParseFile(fileSet, fileName, nil, parser.ParseComments) if err != nil { return nil, fmt.Errorf("parse file: %w", err) } - // get a lines to position map for the file. - f := i.fileSet.File(file.Pos()) - i.lines = f.Lines() - return i.inspect(file) -} - -func (i *inspector) inspect(node ast.Node) ([]*EnvScope, error) { - i.items = make([]*envStruct, 0) - ast.Walk(i, node) - if i.err != nil { - return nil, i.err - } - scopes, err := i.buildScopes() + docResolver, err := newASTTypeDocResolver(fileSet, file) if err != nil { - return nil, fmt.Errorf("build scopes: %w", err) + return nil, fmt.Errorf("new ast type doc resolver: %w", err) } - return scopes, nil -} - -func (i *inspector) getStruct(t *ast.TypeSpec) *envStruct { - typeName := t.Name.Name - for _, s := range i.items { - if s.name == typeName { - return s - } + var commentsHandler astCommentsHandler + if i.all { + commentsHandler = astCommentDummyHandler + } else { + commentsHandler = newASTCommentTargetLineHandler(i.execLine, fileSet.File(file.Pos()).Lines()) } - - s := i.parseType(t) - i.items = append(i.items, s) - return s -} - -func (i *inspector) Visit(n ast.Node) ast.Visitor { - if i.err != nil { - return nil + visitor := newAstVisitor(commentsHandler, docResolver) + visitor.Walk(file) + targetName := i.typeName + if targetName == "" { + targetName = visitor.targetName } + return i.traverseAST(visitor.currentNode, targetName), nil +} - switch t := n.(type) { - case *ast.File: - var err error - i.doc, err = doc.NewFromFiles(i.fileSet, []*ast.File{t}, "./", doc.PreserveAST) - if err != nil { - i.err = fmt.Errorf("parse package doc: %w", err) - return nil - } - case *ast.Comment: - // if type name is not specified we should process the next type - // declaration after the comment with go:generate - // which causes this command to be executed. - if i.typeName != "" || i.all { - return i - } - if !t.Pos().IsValid() { - return i - } - var line int - for l, pos := range i.lines { - if token.Pos(pos) > t.Pos() { - break - } - // $GOLINE env var is 1-based. - line = l + 1 - } - if line != i.execLine { - return i +func (i *inspector) traverseAST(root *visitorNode, targetName string) []*EnvScope { + scopes := make([]*EnvScope, 0, len(root.children)) + logger := logger() + for _, child := range root.children { + if child.kind != nodeType && child.kind != nodeStruct { + panic(fmt.Sprintf("expected type node root child, got %v", child.kind)) } - i.pendingType = true - return i - case *ast.TypeSpec: - debug("type spec: %s (%T) (%d-%d)", t.Name.Name, t.Type, t.Pos(), t.End()) - if i.typeName == "" && i.pendingType { - i.typeName = t.Name.Name + if !i.all && targetName != child.typeName { + logger.Printf("inspector: (traverse) skipping node: %v", child.typeName) + continue } + logger.Printf("inspector: (traverse) process node: %v", child.typeName) - if st, ok := t.Type.(*ast.StructType); ok { - i.processStruct(t, st) - } - // reset pending type flag event if this type - // is not processable (e.g. interface type). - i.pendingType = false - case *ast.StructType: - posRange := [2]token.Pos{t.Pos(), t.End()} - as, ok := i.anonymousStructs[posRange] - if !ok { - return i - } - typeSpec := &ast.TypeSpec{ - Name: &ast.Ident{Name: as.name}, - Doc: as.doc, - Comment: as.comments, + if scope := newScope(child, i.useFieldNames); scope != nil { + scopes = append(scopes, scope) } - i.processStruct(typeSpec, t) - - debug("struct type: %T (%d-%d)", t, t.Pos(), t.End()) } - return i + return scopes } -func (i *inspector) processStruct(t *ast.TypeSpec, st *ast.StructType) { - str := i.getStruct(t) - debug("parsing struct %s", str.name) - for _, field := range st.Fields.List { - items := i.parseField(field) - if len(items) == 0 { - continue - } - str.fields = append(str.fields, items...) +func newScope(node *visitorNode, useFieldNames bool) *EnvScope { + if len(node.names) != 1 { + panic("type node must have exactly one name") } -} -func (i *inspector) parseType(t *ast.TypeSpec) *envStruct { - typeName := t.Name.Name - docStr := strings.TrimSpace(t.Doc.Text()) - if docStr == "" { - for _, t := range i.doc.Types { - if t.Name == typeName { - docStr = strings.TrimSpace(t.Doc) - break - } - } - } - return &envStruct{ - name: typeName, - doc: docStr, - } -} + logger := logger() + logger.Printf("inspecctor: (scope) got node: %v", node.names) -func getTagValues(tag, tagName string) []string { - tagPrefix := tagName + ":" - if !strings.Contains(tag, tagPrefix) { - return nil + scope := &EnvScope{ + Name: node.names[0], + Doc: node.doc, } - tagValue := strings.Split(tag, tagPrefix)[1] - leftQ := strings.Index(tagValue, `"`) - if leftQ == -1 || leftQ == len(tagValue)-1 { - return nil + for _, child := range node.children { + if items := newDocItems(child, useFieldNames, ""); len(items) > 0 { + logger.Printf("inspector: (scope) add items: %d", len(items)) + scope.Vars = append(scope.Vars, items...) + } else { + logger.Printf("inspector: (scope) no items") + } } - rightQ := strings.Index(tagValue[leftQ+1:], `"`) - if rightQ == -1 { + if len(scope.Vars) == 0 { return nil } - tagValue = tagValue[leftQ+1 : leftQ+rightQ+1] - return strings.Split(tagValue, ",") + return scope } -func (i *inspector) parseField(f *ast.Field) (out []envField) { - if f.Tag == nil && !i.useFieldNames { - return - } - - var tag string - if t := f.Tag; t != nil { - tag = t.Value - } - - envPrefix := getTagValues(tag, "envPrefix") - if len(envPrefix) > 0 && envPrefix[0] != "" { - var item envField - item.envPrefix = envPrefix[0] - item.kind = envFieldKindStruct - switch fieldType := f.Type.(type) { - case *ast.Ident: - item.typeRef = fieldType.Name - case *ast.StructType: - nameGen := fastRandString(16) - typeSpec := &ast.TypeSpec{ - Name: &ast.Ident{Name: nameGen}, - Type: fieldType, - Doc: f.Doc, +func newDocItems(node *visitorNode, useFieldNames bool, envPrefix string) []*EnvDocItem { + logger := logger() + builder := new(envDocItemsBuilder).apply( + withEnvDocItemEnvPrefix(envPrefix), + withEnvDocItemDoc(node.doc), + ) + logger.Printf("inspector: (items) process node: %v, envPrefix=%q", node.names, envPrefix) + if node.kind == nodeField && node.typeRef != nil { + if tags := getTagValues(node.tag, "envPrefix"); len(tags) > 0 { + envPrefix = strConcat(envPrefix, tags[0]) + } + logger.Printf("inspector: (items) get subitem fields for typeref: %q, envPrefix=%q", node.typeRef.names, envPrefix) + typeRef := node.typeRef + builder.apply(withEnvDocItemDoc(typeRef.doc), withEnvDocEmptyNames) + for _, subItem := range node.typeRef.children { + logger.Printf("inspector: (items) add subitem for typeref %q: %q", node.typeRef.names, subItem.names) + if items := newDocItems(subItem, useFieldNames, envPrefix); len(items) > 0 { + builder.apply(withEnvDocItemAddChildren(items)) } - i.getStruct(typeSpec) - item.typeRef = nameGen - posRange := [2]token.Pos{fieldType.Pos(), fieldType.End()} - i.anonymousStructs[posRange] = anonymousStruct{ - name: nameGen, - doc: f.Doc, - comments: f.Comment, - } - debug("anonymous struct found: %s (%d-%d)", nameGen, f.Type.Pos(), f.Type.End()) - - default: - panic(fmt.Sprintf("unsupported field type: %T", f.Type)) - } - fieldNames := make([]string, len(f.Names)) - for i, name := range f.Names { - fieldNames[i] = name.Name } - item.fieldName = strings.Join(fieldNames, ", ") - out = []envField{item} - return + debugBuilder(logger, "inspector: (items) typeref builder: ", builder) + return builder.items() } - if !strings.Contains(tag, "env:") && !i.useFieldNames { - return + if node.tag == "" && !useFieldNames { + logger.Printf("inspector: (items) no tag and no field names, skip node: %q", node.names) + return nil } - tagValues := getTagValues(tag, "env") - if len(tagValues) > 0 && tagValues[0] != "" { - var item envField - item.name = tagValues[0] - item.kind = envFieldKindPlain - out = []envField{item} - } else if i.useFieldNames { - out = make([]envField, len(f.Names)) - for i, name := range f.Names { - out[i].name = camelToSnake(name.Name) - out[i].kind = envFieldKindPlain + tagName, opts := parseEnvTag(node.tag) + if tagName != "" { + logger.Printf("inspector: (items) tag name: %q", tagName) + builder.apply(withEnvDocItemNames(tagName)) + } else if useFieldNames { + logger.Printf("inspector: (items) field names: %q", node.names) + names := make([]string, len(node.names)) + for i, name := range node.names { + names[i] = camelToSnake(name) } + builder.apply(withEnvDocItemNames(names...)) } else { - return + logger.Printf("inspector: (items) no tag name and not using field names") + return nil } - docStr := strings.TrimSpace(f.Doc.Text()) - if docStr == "" { - docStr = strings.TrimSpace(f.Comment.Text()) + // Check if the field type is a slice or array, then use default separator + if node.isArray && opts.Separator == "" { + opts.Separator = "," } - for i := range out { - out[i].doc = docStr + + builder.apply(withEnvDocItemOpts(opts)) + + debugBuilder(logger, "inspector: (items) builder: ", builder) + return builder.items() +} + +func parseEnvTag(tag string) (string, EnvVarOptions) { + tagValues := getTagValues(tag, "env") + var tagName string + if len(tagValues) > 0 { + tagName = tagValues[0] } var opts EnvVarOptions @@ -316,87 +187,24 @@ func (i *inspector) parseField(f *ast.Field) (out []envField) { if len(envSeparator) > 0 { opts.Separator = envSeparator[0] } - // Check if the field type is a slice or array - if _, ok := f.Type.(*ast.ArrayType); ok && opts.Separator == "" { - opts.Separator = "," - } - for i := range out { - out[i].opts = opts - } - return + return tagName, opts } -func (i *inspector) buildScopes() ([]*EnvScope, error) { - scopes := make([]*EnvScope, 0, len(i.items)) - for _, s := range i.items { - if !i.all && s.name != i.typeName { - debug("skip %q", s.name) - continue - } - var isAnonymous bool - for _, f := range i.anonymousStructs { - if f.name == s.name { - isAnonymous = true - break - } - } - if isAnonymous { - debug("skip anonymous struct %q", s.name) - continue - } - - debug("process %q", s.name) - scope := &EnvScope{ - Name: s.name, - Doc: s.doc, - } - for _, f := range s.fields { - item, err := i.buildItem(&f, "") - if err != nil { - return nil, err - } - scope.Vars = append(scope.Vars, item) - } - scopes = append(scopes, scope) +func getTagValues(tag, tagName string) []string { + tagPrefix := tagName + ":" + if !strings.Contains(tag, tagPrefix) { + return nil } - return scopes, nil -} - -func (i *inspector) buildItem(f *envField, envPrefix string) (EnvDocItem, error) { - switch f.kind { - case envFieldKindPlain: - return EnvDocItem{ - Name: fmt.Sprintf("%s%s", envPrefix, f.name), - Doc: f.doc, - Opts: f.opts, - debugName: f.name, - }, nil - case envFieldKindStruct: - envPrefix := fmt.Sprintf("%s%s", envPrefix, f.envPrefix) - var base *envStruct - for _, s := range i.items { - if s.name == f.typeRef { - base = s - break - } - } - if base == nil { - return EnvDocItem{}, fmt.Errorf("struct %q not found", f.typeRef) - } - parentItem := EnvDocItem{ - Doc: base.doc, - debugName: base.name, - } - for _, f := range base.fields { - item, err := i.buildItem(&f, envPrefix) - if err != nil { - return EnvDocItem{}, fmt.Errorf("build item `%s`: %w", f.name, err) - } - parentItem.Children = append(parentItem.Children, item) - } - return parentItem, nil - default: - panic("unknown field kind") + tagValue := strings.Split(tag, tagPrefix)[1] + leftQ := strings.Index(tagValue, `"`) + if leftQ == -1 || leftQ == len(tagValue)-1 { + return nil } + rightQ := strings.Index(tagValue[leftQ+1:], `"`) + if rightQ == -1 { + return nil + } + tagValue = tagValue[leftQ+1 : leftQ+rightQ+1] + return strings.Split(tagValue, ",") } diff --git a/inspector_test.go b/inspector_test.go index 38c936a..fd8bec2 100644 --- a/inspector_test.go +++ b/inspector_test.go @@ -2,9 +2,7 @@ package main import ( "embed" - "errors" "fmt" - "go/ast" "io" "os" "path" @@ -13,137 +11,65 @@ import ( func TestTagParsers(t *testing.T) { type testCase struct { - tag string - names []string - useFieldNames bool - expect EnvDocItem - expectList []EnvDocItem - fail bool + tag string + expectName string + expectOpts EnvVarOptions } for i, c := range []testCase{ - {tag: "", fail: true}, - {tag: " ", fail: true}, - {tag: `env:"FOO"`, expect: EnvDocItem{Name: "FOO"}}, - {tag: ` env:FOO `, fail: true}, - {tag: `json:"bar" env:"FOO" qwe:"baz"`, expect: EnvDocItem{Name: "FOO"}}, - {tag: `env:"SECRET,file"`, expect: EnvDocItem{Name: "SECRET", Opts: EnvVarOptions{FromFile: true}}}, + {tag: ""}, + {tag: " "}, + {tag: `env:"FOO"`, expectName: "FOO"}, + {tag: ` env:FOO `}, + {tag: `json:"bar" env:"FOO" qwe:"baz"`, expectName: "FOO"}, + {tag: `env:"SECRET,file"`, expectName: "SECRET", expectOpts: EnvVarOptions{FromFile: true}}, { - tag: `env:"PASSWORD,file" envDefault:"/tmp/password" json:"password"`, - expect: EnvDocItem{Name: "PASSWORD", Opts: EnvVarOptions{FromFile: true, Default: "/tmp/password"}}, + tag: `env:"PASSWORD,file" envDefault:"/tmp/password" json:"password"`, + expectName: "PASSWORD", + expectOpts: EnvVarOptions{FromFile: true, Default: "/tmp/password"}, }, { - tag: `env:"CERTIFICATE,file,expand" envDefault:"${CERTIFICATE_FILE}"`, - expect: EnvDocItem{ - Name: "CERTIFICATE", Opts: EnvVarOptions{ - FromFile: true, Expand: true, Default: "${CERTIFICATE_FILE}", - }, + tag: `env:"CERTIFICATE,file,expand" envDefault:"${CERTIFICATE_FILE}"`, + expectName: "CERTIFICATE", + expectOpts: EnvVarOptions{ + FromFile: true, Expand: true, Default: "${CERTIFICATE_FILE}", }, }, { - tag: `env:"SECRET_KEY,required" json:"secret_key"`, - expect: EnvDocItem{Name: "SECRET_KEY", Opts: EnvVarOptions{Required: true}}, + tag: `env:"SECRET_KEY,required" json:"secret_key"`, + expectName: "SECRET_KEY", + expectOpts: EnvVarOptions{Required: true}, }, { - tag: `json:"secret_val" env:"SECRET_VAL,notEmpty"`, - expect: EnvDocItem{Name: "SECRET_VAL", Opts: EnvVarOptions{Required: true, NonEmpty: true}}, + tag: `json:"secret_val" env:"SECRET_VAL,notEmpty"`, + expectName: "SECRET_VAL", + expectOpts: EnvVarOptions{Required: true, NonEmpty: true}, }, { - tag: `fooo:"1" env:"JUST_A_MESS,required,notEmpty,file,expand" json:"just_a_mess" envDefault:"${JUST_A_MESS_FILE}" bar:"2"`, - expect: EnvDocItem{ - Name: "JUST_A_MESS", - Opts: EnvVarOptions{ - Required: true, NonEmpty: true, FromFile: true, Expand: true, - Default: "${JUST_A_MESS_FILE}", - }, + tag: `fooo:"1" env:"JUST_A_MESS,required,notEmpty,file,expand" json:"just_a_mess" envDefault:"${JUST_A_MESS_FILE}" bar:"2"`, + expectName: "JUST_A_MESS", + expectOpts: EnvVarOptions{ + Required: true, NonEmpty: true, FromFile: true, Expand: true, + Default: "${JUST_A_MESS_FILE}", }, }, { - tag: `env:"WORDS" envSeparator:";"`, - expect: EnvDocItem{ - Name: "WORDS", - Opts: EnvVarOptions{Separator: ";"}, - }, - }, - { - names: []string{"Foo", "BarBaz"}, - expectList: []EnvDocItem{ - {Name: "FOO"}, - {Name: "BAR_BAZ"}, - }, - useFieldNames: true, - }, - { - names: []string{"Foo"}, - tag: `env:",required"`, - expectList: []EnvDocItem{ - {Name: "FOO", Opts: EnvVarOptions{Required: true}}, - }, - useFieldNames: true, + tag: `env:"WORDS" envSeparator:";"`, + expectName: "WORDS", + expectOpts: EnvVarOptions{Separator: ";"}, }, } { t.Run(fmt.Sprint(i), func(t *testing.T) { - fieldNames := make([]*ast.Ident, len(c.names)) - for i, name := range c.names { - fieldNames[i] = &ast.Ident{Name: name} - } - var tag *ast.BasicLit - if c.tag != "" { - tag = &ast.BasicLit{Value: c.tag} - } - field := &ast.Field{ - Tag: tag, - Names: fieldNames, - } - - i := inspector{ - useFieldNames: c.useFieldNames, + name, opts := parseEnvTag(c.tag) + if e, a := c.expectName, name; e != a { + t.Errorf("expected[%d] name %q, got %q", i, e, a) } - - expect := c.expectList - if len(expect) == 0 && c.expect.Name != "" { - expect = []EnvDocItem{c.expect} - } - - actual := i.parseField(field) - if c.fail { - if actual != nil { - t.Errorf("expected nil, got %#v", actual) - } - return - } - if len(expect) != len(actual) { - t.Errorf("expected %d items, got %d", len(expect), len(actual)) - } - for i, e := range expect { - a := actual[i] - if e.Name != a.name { - t.Errorf("expected[%d] name %q, got %q", i, e.Name, a.name) - } - if e.Doc != a.doc { - t.Errorf("expected[%d] doc %q, got %q", i, e.Doc, a.doc) - } - if e.Opts != a.opts { - t.Errorf("expected[%d] opts %#v, got %#v", i, e.Opts, a.opts) - } + if e, a := c.expectOpts, opts; e != a { + t.Errorf("expected[%d] opts %#v, got %#v", i, e, a) } }) } } -func TestInspectorError(t *testing.T) { - sourceFile := path.Join(t.TempDir(), "tmp.go") - if err := copyTestFile(path.Join("testdata", "type.go"), sourceFile); err != nil { - t.Fatal("Copy test file data", err) - } - insp := newInspector("", true, 0, false) - targetErr := errors.New("target error") - insp.err = targetErr - _, err := insp.inspectFile(sourceFile) - if err != targetErr { - t.Errorf("Expected error %q, got %q", targetErr, err) - } -} - //go:embed testdata var testdata embed.FS @@ -157,13 +83,13 @@ func TestInspector(t *testing.T) { typeName string goLine int all bool - expect []EnvDocItem - expectScopes []EnvScope + expect []*EnvDocItem + expectScopes []*EnvScope }{ { name: "go_generate.go", goLine: 3, - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Name: "FOO", Doc: "Foo stub", @@ -173,7 +99,7 @@ func TestInspector(t *testing.T) { { name: "tags.go", typeName: "Type1", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Name: "SECRET", Doc: "Secret is a secret value that is read from a file.", @@ -207,7 +133,7 @@ func TestInspector(t *testing.T) { { name: "type.go", typeName: "Type1", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Name: "FOO", Doc: "Foo stub", @@ -217,7 +143,7 @@ func TestInspector(t *testing.T) { { name: "arrays.go", typeName: "Arrays", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Name: "DOT_SEPARATED", Doc: "DotSeparated stub", @@ -233,7 +159,7 @@ func TestInspector(t *testing.T) { { name: "comments.go", typeName: "Comments", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Name: "FOO", Doc: "Foo stub", @@ -247,10 +173,10 @@ func TestInspector(t *testing.T) { { name: "all.go", all: true, - expectScopes: []EnvScope{ + expectScopes: []*EnvScope{ { Name: "Foo", - Vars: []EnvDocItem{ + Vars: []*EnvDocItem{ { Name: "ONE", Doc: "One is a one.", @@ -263,7 +189,7 @@ func TestInspector(t *testing.T) { }, { Name: "Bar", - Vars: []EnvDocItem{ + Vars: []*EnvDocItem{ { Name: "THREE", Doc: "Three is a three.", @@ -279,11 +205,11 @@ func TestInspector(t *testing.T) { { name: "envprefix.go", typeName: "Settings", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Doc: "Database is the database settings.", debugName: "Database", - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "DB_PORT", Doc: "Port is the port to connect to", @@ -311,7 +237,7 @@ func TestInspector(t *testing.T) { { Doc: "ServerConfig is the server settings.", debugName: "Server", - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "SERVER_PORT", Doc: "Port is the port to listen on", @@ -325,7 +251,7 @@ func TestInspector(t *testing.T) { { Doc: "TimeoutConfig is the timeout settings.", debugName: "Timeout", - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "SERVER_TIMEOUT_READ", Doc: "Read is the read timeout", @@ -349,10 +275,10 @@ func TestInspector(t *testing.T) { { name: "anonymous.go", typeName: "Config", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { Doc: "Repo is the configuration for the repository.", - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "REPO_CONN", Doc: "Conn is the connection string for the repository.", @@ -365,9 +291,9 @@ func TestInspector(t *testing.T) { { name: "nodocs.go", typeName: "Config", - expect: []EnvDocItem{ + expect: []*EnvDocItem{ { - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "REPO_CONN", Opts: EnvVarOptions{Required: true, NonEmpty: true}, @@ -379,7 +305,7 @@ func TestInspector(t *testing.T) { } { scopes := c.expectScopes if scopes == nil { - scopes = []EnvScope{ + scopes = []*EnvScope{ { Name: c.typeName, Vars: c.expect, @@ -409,8 +335,9 @@ func copyTestFile(name string, dest string) error { return nil } -func inspectorTester(name string, typeName string, all bool, lineN int, expect []EnvScope) func(*testing.T) { +func inspectorTester(name string, typeName string, all bool, lineN int, expect []*EnvScope) func(*testing.T) { return func(t *testing.T) { + t.Logf("inspect name=%q typeName=%q all=%v lineN=%d", name, typeName, all, lineN) sourceFile := path.Join(t.TempDir(), "tmp.go") if err := copyTestFile(path.Join("testdata", name), sourceFile); err != nil { t.Fatal("Copy test file data", err) @@ -442,7 +369,7 @@ func inspectorTester(name string, typeName string, all bool, lineN int, expect [ } } -func testScopeVar(t *testing.T, logPrefix string, expect, actual EnvDocItem) { +func testScopeVar(t *testing.T, logPrefix string, expect, actual *EnvDocItem) { t.Helper() if expect.Name != actual.Name { diff --git a/log.go b/log.go new file mode 100644 index 0000000..912843b --- /dev/null +++ b/log.go @@ -0,0 +1,26 @@ +package main + +import ( + "log" + "os" +) + +type nullWriter struct{} + +func (nullWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +var ( + nullLogger = log.New(nullWriter{}, "", 0) + debugLogger = log.New(os.Stdout, "DEBUG: ", log.Ldate|log.Ltime) +) + +var debugLogs = false + +func logger() *log.Logger { + if debugLogs { + return debugLogger + } + return nullLogger +} diff --git a/log_test.go b/log_test.go new file mode 100644 index 0000000..5e3250e --- /dev/null +++ b/log_test.go @@ -0,0 +1,21 @@ +package main + +import "testing" + +func TestLog(t *testing.T) { + flag := debugLogs + t.Cleanup(func() { + debugLogs = flag + }) + + if debugLogs { + t.Fatalf("Expected debugLogs to be false, got %v", debugLogs) + } + if l := logger(); l != nullLogger { + t.Fatalf("Expected nil logger, got %v", l) + } + debugLogs = true + if l := logger(); l != debugLogger { + t.Fatalf("Expected debug logger, got %v", l) + } +} diff --git a/main.go b/main.go index 5bcbe9d..9d724d2 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ type appConfig struct { envPrefix string noStyles bool fieldNames bool + debug bool } func (cfg *appConfig) parseFlags(f *flag.FlagSet) error { @@ -27,6 +28,8 @@ func (cfg *appConfig) parseFlags(f *flag.FlagSet) error { f.StringVar(&cfg.envPrefix, "env-prefix", "", "Environment variable prefix") f.BoolVar(&cfg.noStyles, "no-styles", false, "Disable styles in html output") f.BoolVar(&cfg.fieldNames, "field-names", false, "Use field names if tag is not specified") + f.BoolVar(&cfg.debug, "debug", false, "Enable debug mode") + if err := f.Parse(os.Args[1:]); err != nil { return fmt.Errorf("parsing CLI args: %w", err) } @@ -61,6 +64,9 @@ func main() { if err != nil { fatal(err) } + if cfg.debug { + debugLogs = true + } if err := run(&cfg); err != nil { fatal("Generate error:", err) } diff --git a/main_test.go b/main_test.go index 69d66b5..1f333b7 100644 --- a/main_test.go +++ b/main_test.go @@ -60,7 +60,17 @@ func TestConfig(t *testing.T) { _ = getTestConfig(t, true) }) t.Run("bad-env", func(t *testing.T) { - t.Setenv("GOFILE", "") + os.Args = []string{ + "cmd", + "-output", "test.md", + "-type", "test", + "-no-styles", + "-format", "markdown", + "-env-prefix", "TEST_", + "-field-names", + "-all", + } + t.Setenv("GOFILE", "test.go") t.Setenv("GOLINE", "abc") _ = getTestConfig(t, true) }) @@ -116,7 +126,6 @@ func TestMainRun(t *testing.T) { } outputFile := path.Join(t.TempDir(), "example.md") config := appConfig{ - typeName: "Type1", formatName: "markdown", outputFileName: outputFile, inputFileName: inputFile, @@ -130,6 +139,26 @@ func TestMainRun(t *testing.T) { t.Fatal("run", err) } }) + t.Run("with-type", func(t *testing.T) { + inputFile := path.Join(t.TempDir(), "example.go") + if err := copyTestFile(path.Join("testdata", "type.go"), inputFile); err != nil { + t.Fatal("copy test file", err) + } + outputFile := path.Join(t.TempDir(), "example.md") + config := appConfig{ + typeName: "Type1", + formatName: "markdown", + outputFileName: outputFile, + inputFileName: inputFile, + execLine: 0, + envPrefix: "TEST_", + noStyles: true, + fieldNames: true, + } + if err := run(&config); err != nil { + t.Fatal("run", err) + } + }) t.Run("bad-out", func(t *testing.T) { inputFile := path.Join(t.TempDir(), "example.go") if err := copyTestFile(path.Join("testdata", "type.go"), inputFile); err != nil { diff --git a/render.go b/render.go index 6455177..28b4754 100644 --- a/render.go +++ b/render.go @@ -69,11 +69,12 @@ func newRenderContext(scopes []*EnvScope, envPrefix string, noStyles bool) rende return res } -func newRenderItem(item EnvDocItem, envPrefix string) renderItem { +func newRenderItem(item *EnvDocItem, envPrefix string) renderItem { + log := logger() children := make([]renderItem, len(item.Children)) - debug("render item %s", item.Name) + log.Printf("render item %s", item.Name) for i, child := range item.Children { - debug("render child item %s", child.Name) + log.Printf("render child item %s", child.Name) children[i] = newRenderItem(child, envPrefix) } return renderItem{ diff --git a/render_test.go b/render_test.go index 3ec4e8e..7797518 100644 --- a/render_test.go +++ b/render_test.go @@ -177,14 +177,14 @@ func TestNewRenderContext(t *testing.T) { src := []*EnvScope{ { Name: "First", - Vars: []EnvDocItem{ + Vars: []*EnvDocItem{ { Name: "ONE", Doc: "First one", }, { Doc: "Nested", - Children: []EnvDocItem{ + Children: []*EnvDocItem{ { Name: "NESTED_ONE", Doc: "Nested one", diff --git a/types.go b/types.go index efa2965..018677d 100644 --- a/types.go +++ b/types.go @@ -9,7 +9,7 @@ type EnvDocItem struct { // Opts is a set of options for environment variable parsing. Opts EnvVarOptions // Children is a list of child environment variables. - Children []EnvDocItem + Children []*EnvDocItem debugName string // item name for debug logs. } @@ -20,7 +20,7 @@ type EnvScope struct { // Doc is a documentation text for the scope. Doc string // Vars is a list of environment variables. - Vars []EnvDocItem + Vars []*EnvDocItem } // EnvVarOptions is a set of options for environment variable parsing. diff --git a/utils.go b/utils.go index 8572f84..0c12e9c 100644 --- a/utils.go +++ b/utils.go @@ -1,7 +1,9 @@ package main import ( + "fmt" "io" + "log" "math/rand" "strings" "unicode" @@ -40,3 +42,91 @@ func fastRandString(n int) string { } return string(b) } + +type ( + envDocItemBuilderOp func(*envDocItemsBuilder) + envDocItemsBuilder struct { + envPrefix string + names []string + doc string + opts EnvVarOptions + children []*EnvDocItem + } +) + +func withEnvDocItemEnvPrefix(envPrefix string) envDocItemBuilderOp { + return func(b *envDocItemsBuilder) { + b.envPrefix = envPrefix + } +} + +func withEnvDocItemDoc(doc string) envDocItemBuilderOp { + return func(b *envDocItemsBuilder) { + b.doc = doc + } +} + +func withEnvDocItemOpts(opts EnvVarOptions) envDocItemBuilderOp { + return func(b *envDocItemsBuilder) { + b.opts = opts + } +} + +func withEnvDocItemAddChildren(children []*EnvDocItem) envDocItemBuilderOp { + return func(b *envDocItemsBuilder) { + b.children = append(b.children, children...) + } +} + +func withEnvDocItemNames(names ...string) envDocItemBuilderOp { + return func(b *envDocItemsBuilder) { + b.names = names + } +} + +var withEnvDocEmptyNames = withEnvDocItemNames("") + +func (b *envDocItemsBuilder) apply(op ...envDocItemBuilderOp) *envDocItemsBuilder { + for _, o := range op { + o(b) + } + return b +} + +func (b *envDocItemsBuilder) items() []*EnvDocItem { + items := make([]*EnvDocItem, len(b.names)) + for i, name := range b.names { + item := &EnvDocItem{ + Doc: b.doc, + Opts: b.opts, + Children: b.children, + } + if name != "" { + item.Name = fmt.Sprintf("%s%s", b.envPrefix, name) + } + items[i] = item + } + return items +} + +func (b *envDocItemsBuilder) GoString() string { + return fmt.Sprintf("envDocItemsBuilder{envPrefix: %q, names: %q, doc: %q, opts: %v, children: %v}", + b.envPrefix, b.names, b.doc, b.opts, b.children) +} + +func debugBuilder(l *log.Logger, prefix string, b *envDocItemsBuilder) { + l.Printf("%s: %s", prefix, b.GoString()) +} + +func strConcat(s ...string) string { + var b strings.Builder + var size int + for _, v := range s { + size += len(v) + } + b.Grow(size) + for _, v := range s { + b.WriteString(v) + } + return b.String() +} diff --git a/utils_test.go b/utils_test.go index fede6b5..b0cdb0d 100644 --- a/utils_test.go +++ b/utils_test.go @@ -52,3 +52,11 @@ func TestCamelToSnake(t *testing.T) { } } } + +func TestRandString(t *testing.T) { + for i := 0; i < 100; i++ { + if got := fastRandString(10); len(got) != 10 { + t.Fatalf("expected randString(10) to be 10 characters long, got %d", len(got)) + } + } +}